From e596834096d73646d91788f2a6c56e8da75c965b Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sat, 28 Jan 2023 20:29:06 -0500 Subject: [PATCH] Add "last access" to access tokens --- cmd/app_test.go | 3 +- server/config.go | 2 + server/server.go | 19 +++-- server/server_account.go | 31 ++++++--- server/server_account_test.go | 85 ++++++----------------- server/server_payments_test.go | 19 +++-- server/server_test.go | 4 +- server/types.go | 8 ++- server/visitor.go | 22 ++++-- user/manager.go | 111 ++++++++++++++++++++++-------- user/manager_test.go | 68 ++++++++++++++++-- user/types.go | 15 +++- web/public/static/langs/en.json | 2 + web/src/components/Account.js | 23 +++++-- web/src/components/Preferences.js | 9 ++- 15 files changed, 276 insertions(+), 145 deletions(-) diff --git a/cmd/app_test.go b/cmd/app_test.go index 9873dd0..0e8f4cd 100644 --- a/cmd/app_test.go +++ b/cmd/app_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "github.com/urfave/cli/v2" "heckel.io/ntfy/client" + "heckel.io/ntfy/log" "os" "strings" "testing" @@ -13,7 +14,7 @@ import ( // This only contains helpers so far func TestMain(m *testing.M) { - // log.SetOutput(io.Discard) + log.SetLevel(log.WarnLevel) os.Exit(m.Run()) } diff --git a/server/config.go b/server/config.go index 9ab2de9..a1d7c82 100644 --- a/server/config.go +++ b/server/config.go @@ -77,6 +77,7 @@ type Config struct { AuthStartupQueries string AuthDefault user.Permission AuthBcryptCost int + AuthStatsQueueWriterInterval time.Duration AttachmentCacheDir string AttachmentTotalSizeLimit int64 AttachmentFileSizeLimit int64 @@ -145,6 +146,7 @@ func NewConfig() *Config { AuthStartupQueries: "", AuthDefault: user.NewPermission(true, true), AuthBcryptCost: user.DefaultUserPasswordBcryptCost, + AuthStatsQueueWriterInterval: user.DefaultUserStatsQueueWriterInterval, AttachmentCacheDir: "", AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit, AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit, diff --git a/server/server.go b/server/server.go index 538109a..c2357d5 100644 --- a/server/server.go +++ b/server/server.go @@ -171,7 +171,7 @@ func New(conf *Config) (*Server, error) { } var userManager *user.Manager if conf.AuthFile != "" { - userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, user.DefaultUserStatsQueueWriterInterval) + userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, conf.AuthStatsQueueWriterInterval) if err != nil { return nil, err } @@ -598,7 +598,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes } u := v.User() if s.userManager != nil && u != nil && u.Tier != nil { - s.userManager.EnqueueStats(u.ID, v.Stats()) + go s.userManager.EnqueueStats(u.ID, v.Stats()) } s.mu.Lock() s.messages++ @@ -1620,7 +1620,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) { return nil, errHTTPUnauthorized } if strings.HasPrefix(value, "Bearer") { - return s.authenticateBearerAuth(value) + return s.authenticateBearerAuth(r, value) } return s.authenticateBasicAuth(r, value) } @@ -1634,9 +1634,18 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use return s.userManager.Authenticate(username, password) } -func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) { +func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) { token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer")) - return s.userManager.AuthenticateToken(token) + u, err := s.userManager.AuthenticateToken(token) + if err != nil { + return nil, err + } + ip := extractIPAddress(r, s.config.BehindProxy) + go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{ + LastAccess: time.Now(), + LastOrigin: ip, + }) + return u, nil } func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor { diff --git a/server/server_account.go b/server/server_account.go index 5de8df9..0c5b25b 100644 --- a/server/server_account.go +++ b/server/server_account.go @@ -6,6 +6,7 @@ import ( "heckel.io/ntfy/user" "heckel.io/ntfy/util" "net/http" + "net/netip" "strings" "time" ) @@ -122,10 +123,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis if len(tokens) > 0 { response.Tokens = make([]*apiAccountTokenResponse, 0) for _, t := range tokens { + var lastOrigin string + if t.LastOrigin != netip.IPv4Unspecified() { + lastOrigin = t.LastOrigin.String() + } response.Tokens = append(response.Tokens, &apiAccountTokenResponse{ - Token: t.Value, - Label: t.Label, - Expires: t.Expires.Unix(), + Token: t.Value, + Label: t.Label, + LastAccess: t.LastAccess.Unix(), + LastOrigin: lastOrigin, + Expires: t.Expires.Unix(), }) } } @@ -192,14 +199,16 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request if req.Expires != nil { expires = time.Unix(*req.Expires, 0) } - token, err := s.userManager.CreateToken(v.User().ID, label, expires) + token, err := s.userManager.CreateToken(v.User().ID, label, expires, v.IP()) if err != nil { return err } response := &apiAccountTokenResponse{ - Token: token.Value, - Label: token.Label, - Expires: token.Expires.Unix(), + Token: token.Value, + Label: token.Label, + LastAccess: token.LastAccess.Unix(), + LastOrigin: token.LastOrigin.String(), + Expires: token.Expires.Unix(), } return s.writeJSON(w, response) } @@ -228,9 +237,11 @@ func (s *Server) handleAccountTokenUpdate(w http.ResponseWriter, r *http.Request return err } response := &apiAccountTokenResponse{ - Token: token.Value, - Label: token.Label, - Expires: token.Expires.Unix(), + Token: token.Value, + Label: token.Label, + LastAccess: token.LastAccess.Unix(), + LastOrigin: token.LastOrigin.String(), + Expires: token.Expires.Unix(), } return s.writeJSON(w, response) } diff --git a/server/server_account_test.go b/server/server_account_test.go index bd5bc2c..fa98772 100644 --- a/server/server_account_test.go +++ b/server/server_account_test.go @@ -7,6 +7,7 @@ import ( "heckel.io/ntfy/user" "heckel.io/ntfy/util" "io" + "net/netip" "strings" "testing" "time" @@ -28,6 +29,10 @@ func TestAccount_Signup_Success(t *testing.T) { token, _ := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body)) require.NotEmpty(t, token.Token) require.True(t, time.Now().Add(71*time.Hour).Unix() < token.Expires) + require.True(t, strings.HasPrefix(token.Token, "tk_")) + require.Equal(t, "9.9.9.9", token.LastOrigin) + require.True(t, token.LastAccess > time.Now().Unix()-1) + require.True(t, token.LastAccess < time.Now().Unix()+1) rr = request(t, s, "GET", "/v1/account", "", map[string]string{ "Authorization": util.BearerAuth(token.Token), @@ -161,7 +166,7 @@ func TestAccount_ChangeSettings(t *testing.T) { require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) u, _ := s.userManager.User("phil") - token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0)) + token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified()) rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{ "Authorization": util.BasicAuth("phil", "phil"), @@ -558,7 +563,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { // Subscribe anonymously anonCh, userCh := make(chan bool), make(chan bool) go func() { - rr := request(t, s, "GET", "/mytopic/json", ``, nil) + rr := request(t, s, "GET", "/mytopic/json", ``, nil) // This blocks until it's killed! require.Equal(t, 200, rr.Code) messages := toMessages(t, rr.Body.String()) require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message! @@ -570,7 +575,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { // Subscribe with user go func() { - rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ + rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ // Blocks! "Authorization": util.BasicAuth("phil", "mypass"), }) require.Equal(t, 200, rr.Code) @@ -584,10 +589,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { }() // Publish message (before reservation) - time.Sleep(time.Second) // Wait for subscribers + time.Sleep(2 * time.Second) // Wait for subscribers rr = request(t, s, "POST", "/mytopic", "message before reservation", nil) require.Equal(t, 200, rr.Code) - time.Sleep(time.Second) // Wait for subscribers to receive message + time.Sleep(2 * time.Second) // Wait for subscribers to receive message // Reserve a topic rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{ @@ -596,7 +601,11 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { require.Equal(t, 200, rr.Code) // Everyone but phil should be killed - <-anonCh + select { + case <-anonCh: + case <-time.After(5 * time.Second): + t.Fatal("Waiting for anonymous subscription to be killed failed") + } // Publish a message rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{ @@ -606,62 +615,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) { // Kill user Go routine s.topics["mytopic"].CancelSubscribers("") - <-userCh -} - -func TestAccount_Tier_Create(t *testing.T) { - conf := newTestConfigWithAuthFile(t) - s := newTestServer(t, conf) - - // Create tier and user - require.Nil(t, s.userManager.CreateTier(&user.Tier{ - Code: "pro", - Name: "Pro", - MessageLimit: 123, - MessageExpiryDuration: 86400 * time.Second, - EmailLimit: 32, - ReservationLimit: 2, - AttachmentFileSizeLimit: 1231231, - AttachmentTotalSizeLimit: 123123, - AttachmentExpiryDuration: 10800 * time.Second, - AttachmentBandwidthLimit: 21474836480, - })) - require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser)) - require.Nil(t, s.userManager.ChangeTier("phil", "pro")) - - ti, err := s.userManager.Tier("pro") - require.Nil(t, err) - - u, err := s.userManager.User("phil") - require.Nil(t, err) - - // These are populated by different SQL queries - require.Equal(t, ti, u.Tier) - - // Fields - require.True(t, strings.HasPrefix(ti.ID, "ti_")) - require.Equal(t, "pro", ti.Code) - require.Equal(t, "Pro", ti.Name) - require.Equal(t, int64(123), ti.MessageLimit) - require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration) - require.Equal(t, int64(32), ti.EmailLimit) - require.Equal(t, int64(2), ti.ReservationLimit) - require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) - require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) - require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) - require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) -} - -func TestAccount_Tier_Create_With_ID(t *testing.T) { - conf := newTestConfigWithAuthFile(t) - s := newTestServer(t, conf) - - require.Nil(t, s.userManager.CreateTier(&user.Tier{ - ID: "ti_123", - Code: "pro", - })) - - ti, err := s.userManager.Tier("pro") - require.Nil(t, err) - require.Equal(t, "ti_123", ti.ID) + + select { + case <-userCh: + case <-time.After(5 * time.Second): + t.Fatal("Waiting for user subscription to be killed failed") + } } diff --git a/server/server_payments_test.go b/server/server_payments_test.go index 1576ab0..7206a65 100644 --- a/server/server_payments_test.go +++ b/server/server_payments_test.go @@ -258,11 +258,6 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes c.StripeWebhookKey = "webhook key" c.VisitorRequestLimitBurst = 5 c.VisitorRequestLimitReplenish = time.Hour - c.CacheStartupQueries = ` -pragma journal_mode = WAL; -pragma synchronous = normal; -pragma temp_store = memory; -` c.CacheBatchSize = 500 c.CacheBatchTimeout = time.Second s := newTestServer(t, c) @@ -324,6 +319,18 @@ pragma temp_store = memory; }) require.Equal(t, 429, rr.Code) + // Verify some "before-stats" + u, err = s.userManager.User("phil") + require.Nil(t, err) + require.Nil(t, u.Tier) + require.Equal(t, "", u.Billing.StripeCustomerID) + require.Equal(t, "", u.Billing.StripeSubscriptionID) + require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus) + require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix()) + require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) + require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users! + require.Equal(t, int64(0), u.Stats.Emails) + // Simulate Stripe success return URL call (no user context) rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil) require.Equal(t, 303, rr.Code) @@ -337,6 +344,8 @@ pragma temp_store = memory; require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus) require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix()) require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix()) + require.Equal(t, int64(0), u.Stats.Messages) + require.Equal(t, int64(0), u.Stats.Emails) // Now for the fun part: Verify that new rate limits are immediately applied // This only tests the request limiter, which kicks in before the message limiter. diff --git a/server/server_test.go b/server/server_test.go index 3ad42ce..3c49114 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -892,10 +892,8 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) { // if the visitor is unknown c := newTestConfigWithAuthFile(t) + c.AuthStatsQueueWriterInterval = 100 * time.Millisecond s := newTestServer(t, c) - var err error - s.userManager, err = user.NewManager(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, c.AuthBcryptCost, 100*time.Millisecond) - require.Nil(t, err) // Create user, and update it with some message and email stats require.Nil(t, s.userManager.CreateTier(&user.Tier{ diff --git a/server/types.go b/server/types.go index 981d99f..07aee52 100644 --- a/server/types.go +++ b/server/types.go @@ -247,9 +247,11 @@ type apiAccountTokenUpdateRequest struct { } type apiAccountTokenResponse struct { - Token string `json:"token"` - Label string `json:"label,omitempty"` - Expires int64 `json:"expires,omitempty"` // Unix timestamp + Token string `json:"token"` + Label string `json:"label,omitempty"` + LastAccess int64 `json:"last_access,omitempty"` + LastOrigin string `json:"last_origin,omitempty"` + Expires int64 `json:"expires,omitempty"` // Unix timestamp } type apiAccountTier struct { diff --git a/server/visitor.go b/server/visitor.go index d4d2ea1..84816fd 100644 --- a/server/visitor.go +++ b/server/visitor.go @@ -131,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana bandwidthLimiter: nil, // Set in resetLimiters accountLimiter: nil, // Set in resetLimiters, may be nil } - v.resetLimitersNoLock(messages, emails) + v.resetLimitersNoLock(messages, emails, false) return v } @@ -254,6 +254,13 @@ func (v *visitor) User() *user.User { return v.user // May be nil } +// IP returns the visitor IP address +func (v *visitor) IP() netip.Addr { + v.mu.Lock() + defer v.mu.Unlock() + return v.ip +} + // Authenticated returns true if a user successfully authenticated func (v *visitor) Authenticated() bool { v.mu.Lock() @@ -268,7 +275,7 @@ func (v *visitor) SetUser(u *user.User) { shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver v.user = u if shouldResetLimiters { - v.resetLimitersNoLock(0, 0) + v.resetLimitersNoLock(0, 0, true) } } @@ -283,7 +290,7 @@ func (v *visitor) MaybeUserID() string { return "" } -func (v *visitor) resetLimitersNoLock(messages, emails int64) { +func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) { log.Debug("%s Resetting limiters for visitor", v.stringNoLock()) limits := v.limitsNoLock() v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst) @@ -295,6 +302,13 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64) { } else { v.accountLimiter = nil // Users cannot create accounts when logged in } + /* + if enqueueUpdate && v.user != nil { + go v.userManager.EnqueueStats(v.user.ID, &user.Stats{ + Messages: messages, + Emails: emails, + }) + }*/ } func (v *visitor) Limits() *visitorLimits { @@ -361,7 +375,7 @@ func (v *visitor) Info() (*visitorInfo, error) { if u != nil { attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID) } else { - attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String()) + attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.IP().String()) } if err != nil { return nil, err diff --git a/user/manager.go b/user/manager.go index 79db16d..0ed5794 100644 --- a/user/manager.go +++ b/user/manager.go @@ -10,6 +10,7 @@ import ( "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/log" "heckel.io/ntfy/util" + "net/netip" "strings" "sync" "time" @@ -95,6 +96,8 @@ const ( user_id TEXT NOT NULL, token TEXT NOT NULL, label TEXT NOT NULL, + last_access INT NOT NULL, + last_origin TEXT NOT NULL, expires INT NOT NULL, PRIMARY KEY (user_id, token), FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE @@ -127,9 +130,9 @@ const ( selectUserByTokenQuery = ` SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id FROM user u - JOIN user_token t on u.id = t.user_id + JOIN user_token tk on u.id = tk.user_id LEFT JOIN tier t on t.id = u.tier_id - WHERE t.token = ? AND (t.expires = 0 OR t.expires >= ?) + WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?) ` selectUserByStripeCustomerIDQuery = ` SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id @@ -218,16 +221,17 @@ const ( AND topic = ? ` - selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` - selectTokensQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ?` - selectTokenQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ? AND token = ?` - insertTokenQuery = `INSERT INTO user_token (user_id, token, label, expires) VALUES (?, ?, ?, ?)` - updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` - updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` - deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` - deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` - deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` - deleteExcessTokensQuery = ` + selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?` + selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?` + selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?` + insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)` + updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?` + updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?` + updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?` + deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?` + deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?` + deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?` + deleteExcessTokensQuery = ` DELETE FROM user_token WHERE (user_id, token) NOT IN ( SELECT user_id, token @@ -297,16 +301,17 @@ const ( // in a SQLite database. type Manager struct { db *sql.DB - defaultAccess Permission // Default permission if no ACL matches - statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) - bcryptCost int // Makes testing easier + defaultAccess Permission // Default permission if no ACL matches + statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats) + tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate) + bcryptCost int // Makes testing easier mu sync.Mutex } var _ Auther = (*Manager)(nil) // NewManager creates a new Manager instance -func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) (*Manager, error) { +func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) { db, err := sql.Open("sqlite3", filename) if err != nil { return nil, err @@ -321,9 +326,10 @@ func NewManager(filename, startupQueries string, defaultAccess Permission, bcryp db: db, defaultAccess: defaultAccess, statsQueue: make(map[string]*Stats), + tokenQueue: make(map[string]*TokenUpdate), bcryptCost: bcryptCost, } - go manager.userStatsQueueWriter(statsWriterInterval) + go manager.asyncQueueWriter(queueWriterInterval) return manager, nil } @@ -367,14 +373,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) { // CreateToken generates a random token for the given user and returns it. The token expires // after a fixed duration unless ChangeToken is called. This function also prunes tokens for the // given user, if there are too many of them. -func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, error) { +func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) { token := util.RandomStringPrefix(tokenPrefix, tokenLength) tx, err := a.db.Begin() if err != nil { return nil, err } defer tx.Rollback() - if _, err := tx.Exec(insertTokenQuery, userID, token, label, expires.Unix()); err != nil { + access := time.Now() + if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil { return nil, err } rows, err := tx.Query(selectTokenCountQuery, userID) @@ -400,9 +407,11 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, return nil, err } return &Token{ - Value: token, - Label: label, - Expires: expires, + Value: token, + Label: label, + LastAccess: access, + LastOrigin: origin, + Expires: expires, }, nil } @@ -437,20 +446,26 @@ func (a *Manager) Token(userID, token string) (*Token, error) { } func (a *Manager) readToken(rows *sql.Rows) (*Token, error) { - var token, label string - var expires int64 + var token, label, lastOrigin string + var lastAccess, expires int64 if !rows.Next() { return nil, ErrTokenNotFound } - if err := rows.Scan(&token, &label, &expires); err != nil { + if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil { return nil, err } else if err := rows.Err(); err != nil { return nil, err } + lastOriginIP, err := netip.ParseAddr(lastOrigin) + if err != nil { + lastOriginIP = netip.IPv4Unspecified() + } return &Token{ - Value: token, - Label: label, - Expires: time.Unix(expires, 0), + Value: token, + Label: label, + LastAccess: time.Unix(lastAccess, 0), + LastOrigin: lastOriginIP, + Expires: time.Unix(expires, 0), }, nil } @@ -521,7 +536,7 @@ func (a *Manager) ChangeSettings(user *User) error { // ResetStats resets all user stats in the user database. This touches all users. func (a *Manager) ResetStats() error { - a.mu.Lock() + a.mu.Lock() // Includes database query to avoid races! defer a.mu.Unlock() if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil { return err @@ -538,12 +553,23 @@ func (a *Manager) EnqueueStats(userID string, stats *Stats) { a.statsQueue[userID] = stats } -func (a *Manager) userStatsQueueWriter(interval time.Duration) { +// EnqueueTokenUpdate adds the token update to a queue which writes out token access times +// in batches at a regular interval +func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) { + a.mu.Lock() + defer a.mu.Unlock() + a.tokenQueue[tokenID] = update +} + +func (a *Manager) asyncQueueWriter(interval time.Duration) { ticker := time.NewTicker(interval) for range ticker.C { if err := a.writeUserStatsQueue(); err != nil { log.Warn("User Manager: Writing user stats queue failed: %s", err.Error()) } + if err := a.writeTokenUpdateQueue(); err != nil { + log.Warn("User Manager: Writing token update queue failed: %s", err.Error()) + } } } @@ -572,6 +598,31 @@ func (a *Manager) writeUserStatsQueue() error { return tx.Commit() } +func (a *Manager) writeTokenUpdateQueue() error { + a.mu.Lock() + if len(a.tokenQueue) == 0 { + a.mu.Unlock() + log.Trace("User Manager: No token updates to commit") + return nil + } + tokenQueue := a.tokenQueue + a.tokenQueue = make(map[string]*TokenUpdate) + a.mu.Unlock() + tx, err := a.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + log.Debug("User Manager: Writing token update queue for %d token(s)", len(tokenQueue)) + for tokenID, update := range tokenQueue { + log.Trace("User Manager: Updating token %s with last access time %v", tokenID, update.LastAccess.Unix()) + if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil { + return err + } + } + return tx.Commit() +} + // Authorize returns nil if the given user has access to the given topic using the desired // permission. The user param may be nil to signal an anonymous user. func (a *Manager) Authorize(user *User, topic string, perm Permission) error { diff --git a/user/manager_test.go b/user/manager_test.go index 4c5e368..1a12772 100644 --- a/user/manager_test.go +++ b/user/manager_test.go @@ -5,6 +5,7 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" "heckel.io/ntfy/util" + "net/netip" "path/filepath" "strings" "testing" @@ -139,7 +140,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) { require.Nil(t, err) require.False(t, u.Deleted) - token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour)) + token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) u, err = a.Authenticate("user", "pass") @@ -397,7 +398,7 @@ func TestManager_Token_Valid(t *testing.T) { require.Nil(t, err) // Create token for user - token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour)) + token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) require.NotEmpty(t, token.Value) require.Equal(t, "some label", token.Label) @@ -441,12 +442,12 @@ func TestManager_Token_Expire(t *testing.T) { require.Nil(t, err) // Create tokens for user - token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour)) + token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) require.NotEmpty(t, token1.Value) require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix()) - token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour)) + token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) require.NotEmpty(t, token2.Value) require.NotEqual(t, token1.Value, token2.Value) @@ -493,7 +494,7 @@ func TestManager_Token_Extend(t *testing.T) { require.Equal(t, errNoTokenProvided, err) // Create token for user - token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour)) + token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) require.NotEmpty(t, token.Value) @@ -520,7 +521,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) { baseTime := time.Now().Add(24 * time.Hour) tokens := make([]string, 0) for i := 0; i < 12; i++ { - token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour)) + token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified()) require.Nil(t, err) require.NotEmpty(t, token.Value) tokens = append(tokens, token.Value) @@ -624,6 +625,61 @@ func TestManager_ChangeSettings(t *testing.T) { require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName) } +func TestManager_Tier_Create(t *testing.T) { + a := newTestManager(t, PermissionDenyAll) + + // Create tier and user + require.Nil(t, a.CreateTier(&Tier{ + Code: "pro", + Name: "Pro", + MessageLimit: 123, + MessageExpiryDuration: 86400 * time.Second, + EmailLimit: 32, + ReservationLimit: 2, + AttachmentFileSizeLimit: 1231231, + AttachmentTotalSizeLimit: 123123, + AttachmentExpiryDuration: 10800 * time.Second, + AttachmentBandwidthLimit: 21474836480, + })) + require.Nil(t, a.AddUser("phil", "phil", RoleUser)) + require.Nil(t, a.ChangeTier("phil", "pro")) + + ti, err := a.Tier("pro") + require.Nil(t, err) + + u, err := a.User("phil") + require.Nil(t, err) + + // These are populated by different SQL queries + require.Equal(t, ti, u.Tier) + + // Fields + require.True(t, strings.HasPrefix(ti.ID, "ti_")) + require.Equal(t, "pro", ti.Code) + require.Equal(t, "Pro", ti.Name) + require.Equal(t, int64(123), ti.MessageLimit) + require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration) + require.Equal(t, int64(32), ti.EmailLimit) + require.Equal(t, int64(2), ti.ReservationLimit) + require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit) + require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit) + require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration) + require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit) +} + +func TestAccount_Tier_Create_With_ID(t *testing.T) { + a := newTestManager(t, PermissionDenyAll) + + require.Nil(t, a.CreateTier(&Tier{ + ID: "ti_123", + Code: "pro", + })) + + ti, err := a.Tier("pro") + require.Nil(t, err) + require.Equal(t, "ti_123", ti.ID) +} + func TestSqliteCache_Migration_From1(t *testing.T) { filename := filepath.Join(t.TempDir(), "user.db") db, err := sql.Open("sqlite3", filename) diff --git a/user/types.go b/user/types.go index d6c291b..bed4207 100644 --- a/user/types.go +++ b/user/types.go @@ -4,6 +4,7 @@ package user import ( "errors" "github.com/stripe/stripe-go/v74" + "net/netip" "regexp" "time" ) @@ -46,9 +47,17 @@ type Auther interface { // Token represents a user token, including expiry date type Token struct { - Value string - Label string - Expires time.Time + Value string + Label string + LastAccess time.Time + LastOrigin netip.Addr + Expires time.Time +} + +// TokenUpdate holds information about the last access time and origin IP address of a token +type TokenUpdate struct { + LastAccess time.Time + LastOrigin netip.Addr } // Prefs represents a user's configuration settings diff --git a/web/public/static/langs/en.json b/web/public/static/langs/en.json index 722652a..652282b 100644 --- a/web/public/static/langs/en.json +++ b/web/public/static/langs/en.json @@ -226,6 +226,7 @@ "account_tokens_description": "Use access tokens when publishing and subscribing via the ntfy API, so you don't have to send your account credentials. Check out the documentation to learn more.", "account_tokens_table_token_header": "Token", "account_tokens_table_label_header": "Label", + "account_tokens_table_last_access_header": "Last access", "account_tokens_table_expires_header": "Expires", "account_tokens_table_never_expires": "Never expires", "account_tokens_table_current_session": "Current browser session", @@ -233,6 +234,7 @@ "account_tokens_table_copied_to_clipboard": "Access token copied", "account_tokens_table_cannot_delete_or_edit": "Cannot edit or delete current session token", "account_tokens_table_create_token_button": "Create access token", + "account_tokens_table_last_origin_tooltip": "From IP address {{ip}}, click to lookup", "account_tokens_dialog_title_create": "Create access token", "account_tokens_dialog_title_edit": "Edit access token", "account_tokens_dialog_title_delete": "Delete access token", diff --git a/web/src/components/Account.js b/web/src/components/Account.js index 7552b5b..ffefe13 100644 --- a/web/src/components/Account.js +++ b/web/src/components/Account.js @@ -27,7 +27,7 @@ import DialogContent from "@mui/material/DialogContent"; import TextField from "@mui/material/TextField"; import routes from "./routes"; import IconButton from "@mui/material/IconButton"; -import {formatBytes, formatShortDate, formatShortDateTime, truncateString, validUrl} from "../app/utils"; +import {formatBytes, formatShortDate, formatShortDateTime, openUrl, truncateString, validUrl} from "../app/utils"; import accountApi, {IncorrectPasswordError, UnauthorizedError} from "../app/AccountApi"; import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined'; import {Pref, PrefGroup} from "./Pref"; @@ -43,7 +43,7 @@ import userManager from "../app/UserManager"; import {Paragraph} from "./styles"; import CloseIcon from "@mui/icons-material/Close"; import DialogActions from "@mui/material/DialogActions"; -import {ContentCopy} from "@mui/icons-material"; +import {ContentCopy, Public} from "@mui/icons-material"; import MenuItem from "@mui/material/MenuItem"; import ListItemIcon from "@mui/material/ListItemIcon"; import {PermissionDenyAll, PermissionRead, PermissionReadWrite, PermissionWrite} from "./ReserveIcons"; @@ -506,6 +506,7 @@ const TokensTable = (props) => { {t("account_tokens_table_token_header")} {t("account_tokens_table_label_header")} {t("account_tokens_table_expires_header")} + {t("account_tokens_table_last_access_header")} @@ -513,11 +514,11 @@ const TokensTable = (props) => { {tokens.map(token => ( - + - {token.token.slice(0, 20)} + {token.token.slice(0, 12)} ... handleCopy(token.token)}> @@ -531,7 +532,17 @@ const TokensTable = (props) => { {token.expires ? formatShortDateTime(token.expires) : {t("account_tokens_table_never_expires")}} - + +
+ {formatShortDateTime(token.last_access)} + + openUrl(`https://whatismyipaddress.com/ip/${token.last_origin}`)}> + + + +
+
+ {token.token !== session.token() && <> handleEditClick(token)} aria-label={t("account_tokens_dialog_title_edit")}> diff --git a/web/src/components/Preferences.js b/web/src/components/Preferences.js index 97e001f..ea14fcd 100644 --- a/web/src/components/Preferences.js +++ b/web/src/components/Preferences.js @@ -300,10 +300,9 @@ const UserTable = (props) => { key={user.baseUrl} sx={{'&:last-child td, &:last-child th': {border: 0}}} > - {user.username} + {user.username} {user.baseUrl} - + {(!session.exists() || user.baseUrl !== config.base_url) && <> handleEditClick(user)} aria-label={t("prefs_users_edit_button")}> @@ -597,7 +596,7 @@ const ReservationsTable = (props) => { {props.reservations.map(reservation => ( {reservation.topic} @@ -628,7 +627,7 @@ const ReservationsTable = (props) => { } - + {!localSubscriptions[reservation.topic] && } label="Not subscribed" color="primary" variant="outlined"/> }