Add "last access" to access tokens
This commit is contained in:
parent
000bf27c87
commit
e596834096
15 changed files with 276 additions and 145 deletions
|
@ -77,6 +77,7 @@ type Config struct {
|
|||
AuthStartupQueries string
|
||||
AuthDefault user.Permission
|
||||
AuthBcryptCost int
|
||||
AuthStatsQueueWriterInterval time.Duration
|
||||
AttachmentCacheDir string
|
||||
AttachmentTotalSizeLimit int64
|
||||
AttachmentFileSizeLimit int64
|
||||
|
@ -145,6 +146,7 @@ func NewConfig() *Config {
|
|||
AuthStartupQueries: "",
|
||||
AuthDefault: user.NewPermission(true, true),
|
||||
AuthBcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||
AuthStatsQueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||
AttachmentCacheDir: "",
|
||||
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
||||
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
|
||||
|
|
|
@ -171,7 +171,7 @@ func New(conf *Config) (*Server, error) {
|
|||
}
|
||||
var userManager *user.Manager
|
||||
if conf.AuthFile != "" {
|
||||
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, user.DefaultUserStatsQueueWriterInterval)
|
||||
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, conf.AuthStatsQueueWriterInterval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -598,7 +598,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
|||
}
|
||||
u := v.User()
|
||||
if s.userManager != nil && u != nil && u.Tier != nil {
|
||||
s.userManager.EnqueueStats(u.ID, v.Stats())
|
||||
go s.userManager.EnqueueStats(u.ID, v.Stats())
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.messages++
|
||||
|
@ -1620,7 +1620,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
|
|||
return nil, errHTTPUnauthorized
|
||||
}
|
||||
if strings.HasPrefix(value, "Bearer") {
|
||||
return s.authenticateBearerAuth(value)
|
||||
return s.authenticateBearerAuth(r, value)
|
||||
}
|
||||
return s.authenticateBasicAuth(r, value)
|
||||
}
|
||||
|
@ -1634,9 +1634,18 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use
|
|||
return s.userManager.Authenticate(username, password)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
|
||||
func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
||||
return s.userManager.AuthenticateToken(token)
|
||||
u, err := s.userManager.AuthenticateToken(token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||
go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{
|
||||
LastAccess: time.Now(),
|
||||
LastOrigin: ip,
|
||||
})
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
@ -122,10 +123,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
|||
if len(tokens) > 0 {
|
||||
response.Tokens = make([]*apiAccountTokenResponse, 0)
|
||||
for _, t := range tokens {
|
||||
var lastOrigin string
|
||||
if t.LastOrigin != netip.IPv4Unspecified() {
|
||||
lastOrigin = t.LastOrigin.String()
|
||||
}
|
||||
response.Tokens = append(response.Tokens, &apiAccountTokenResponse{
|
||||
Token: t.Value,
|
||||
Label: t.Label,
|
||||
Expires: t.Expires.Unix(),
|
||||
Token: t.Value,
|
||||
Label: t.Label,
|
||||
LastAccess: t.LastAccess.Unix(),
|
||||
LastOrigin: lastOrigin,
|
||||
Expires: t.Expires.Unix(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -192,14 +199,16 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
|
|||
if req.Expires != nil {
|
||||
expires = time.Unix(*req.Expires, 0)
|
||||
}
|
||||
token, err := s.userManager.CreateToken(v.User().ID, label, expires)
|
||||
token, err := s.userManager.CreateToken(v.User().ID, label, expires, v.IP())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response := &apiAccountTokenResponse{
|
||||
Token: token.Value,
|
||||
Label: token.Label,
|
||||
Expires: token.Expires.Unix(),
|
||||
Token: token.Value,
|
||||
Label: token.Label,
|
||||
LastAccess: token.LastAccess.Unix(),
|
||||
LastOrigin: token.LastOrigin.String(),
|
||||
Expires: token.Expires.Unix(),
|
||||
}
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
@ -228,9 +237,11 @@ func (s *Server) handleAccountTokenUpdate(w http.ResponseWriter, r *http.Request
|
|||
return err
|
||||
}
|
||||
response := &apiAccountTokenResponse{
|
||||
Token: token.Value,
|
||||
Label: token.Label,
|
||||
Expires: token.Expires.Unix(),
|
||||
Token: token.Value,
|
||||
Label: token.Label,
|
||||
LastAccess: token.LastAccess.Unix(),
|
||||
LastOrigin: token.LastOrigin.String(),
|
||||
Expires: token.Expires.Unix(),
|
||||
}
|
||||
return s.writeJSON(w, response)
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
"io"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -28,6 +29,10 @@ func TestAccount_Signup_Success(t *testing.T) {
|
|||
token, _ := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body))
|
||||
require.NotEmpty(t, token.Token)
|
||||
require.True(t, time.Now().Add(71*time.Hour).Unix() < token.Expires)
|
||||
require.True(t, strings.HasPrefix(token.Token, "tk_"))
|
||||
require.Equal(t, "9.9.9.9", token.LastOrigin)
|
||||
require.True(t, token.LastAccess > time.Now().Unix()-1)
|
||||
require.True(t, token.LastAccess < time.Now().Unix()+1)
|
||||
|
||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||
"Authorization": util.BearerAuth(token.Token),
|
||||
|
@ -161,7 +166,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
|||
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
u, _ := s.userManager.User("phil")
|
||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0))
|
||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
|
||||
|
||||
rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "phil"),
|
||||
|
@ -558,7 +563,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
// Subscribe anonymously
|
||||
anonCh, userCh := make(chan bool), make(chan bool)
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, nil)
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, nil) // This blocks until it's killed!
|
||||
require.Equal(t, 200, rr.Code)
|
||||
messages := toMessages(t, rr.Body.String())
|
||||
require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
|
||||
|
@ -570,7 +575,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
|
||||
// Subscribe with user
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ // Blocks!
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
@ -584,10 +589,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
}()
|
||||
|
||||
// Publish message (before reservation)
|
||||
time.Sleep(time.Second) // Wait for subscribers
|
||||
time.Sleep(2 * time.Second) // Wait for subscribers
|
||||
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
time.Sleep(time.Second) // Wait for subscribers to receive message
|
||||
time.Sleep(2 * time.Second) // Wait for subscribers to receive message
|
||||
|
||||
// Reserve a topic
|
||||
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
||||
|
@ -596,7 +601,11 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Everyone but phil should be killed
|
||||
<-anonCh
|
||||
select {
|
||||
case <-anonCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Waiting for anonymous subscription to be killed failed")
|
||||
}
|
||||
|
||||
// Publish a message
|
||||
rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
|
||||
|
@ -606,62 +615,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
|||
|
||||
// Kill user Go routine
|
||||
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
||||
<-userCh
|
||||
}
|
||||
|
||||
func TestAccount_Tier_Create(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
// Create tier and user
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
Name: "Pro",
|
||||
MessageLimit: 123,
|
||||
MessageExpiryDuration: 86400 * time.Second,
|
||||
EmailLimit: 32,
|
||||
ReservationLimit: 2,
|
||||
AttachmentFileSizeLimit: 1231231,
|
||||
AttachmentTotalSizeLimit: 123123,
|
||||
AttachmentExpiryDuration: 10800 * time.Second,
|
||||
AttachmentBandwidthLimit: 21474836480,
|
||||
}))
|
||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
ti, err := s.userManager.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
|
||||
u, err := s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
|
||||
// These are populated by different SQL queries
|
||||
require.Equal(t, ti, u.Tier)
|
||||
|
||||
// Fields
|
||||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||
require.Equal(t, "pro", ti.Code)
|
||||
require.Equal(t, "Pro", ti.Name)
|
||||
require.Equal(t, int64(123), ti.MessageLimit)
|
||||
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
||||
require.Equal(t, int64(32), ti.EmailLimit)
|
||||
require.Equal(t, int64(2), ti.ReservationLimit)
|
||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
||||
}
|
||||
|
||||
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
ID: "ti_123",
|
||||
Code: "pro",
|
||||
}))
|
||||
|
||||
ti, err := s.userManager.Tier("pro")
|
||||
require.Nil(t, err)
|
||||
require.Equal(t, "ti_123", ti.ID)
|
||||
|
||||
select {
|
||||
case <-userCh:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Waiting for user subscription to be killed failed")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -258,11 +258,6 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
|||
c.StripeWebhookKey = "webhook key"
|
||||
c.VisitorRequestLimitBurst = 5
|
||||
c.VisitorRequestLimitReplenish = time.Hour
|
||||
c.CacheStartupQueries = `
|
||||
pragma journal_mode = WAL;
|
||||
pragma synchronous = normal;
|
||||
pragma temp_store = memory;
|
||||
`
|
||||
c.CacheBatchSize = 500
|
||||
c.CacheBatchTimeout = time.Second
|
||||
s := newTestServer(t, c)
|
||||
|
@ -324,6 +319,18 @@ pragma temp_store = memory;
|
|||
})
|
||||
require.Equal(t, 429, rr.Code)
|
||||
|
||||
// Verify some "before-stats"
|
||||
u, err = s.userManager.User("phil")
|
||||
require.Nil(t, err)
|
||||
require.Nil(t, u.Tier)
|
||||
require.Equal(t, "", u.Billing.StripeCustomerID)
|
||||
require.Equal(t, "", u.Billing.StripeSubscriptionID)
|
||||
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
|
||||
require.Equal(t, int64(0), u.Stats.Emails)
|
||||
|
||||
// Simulate Stripe success return URL call (no user context)
|
||||
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
|
||||
require.Equal(t, 303, rr.Code)
|
||||
|
@ -337,6 +344,8 @@ pragma temp_store = memory;
|
|||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
||||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||
require.Equal(t, int64(0), u.Stats.Messages)
|
||||
require.Equal(t, int64(0), u.Stats.Emails)
|
||||
|
||||
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||
// This only tests the request limiter, which kicks in before the message limiter.
|
||||
|
|
|
@ -892,10 +892,8 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
|
|||
// if the visitor is unknown
|
||||
|
||||
c := newTestConfigWithAuthFile(t)
|
||||
c.AuthStatsQueueWriterInterval = 100 * time.Millisecond
|
||||
s := newTestServer(t, c)
|
||||
var err error
|
||||
s.userManager, err = user.NewManager(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, c.AuthBcryptCost, 100*time.Millisecond)
|
||||
require.Nil(t, err)
|
||||
|
||||
// Create user, and update it with some message and email stats
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
|
|
|
@ -247,9 +247,11 @@ type apiAccountTokenUpdateRequest struct {
|
|||
}
|
||||
|
||||
type apiAccountTokenResponse struct {
|
||||
Token string `json:"token"`
|
||||
Label string `json:"label,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"` // Unix timestamp
|
||||
Token string `json:"token"`
|
||||
Label string `json:"label,omitempty"`
|
||||
LastAccess int64 `json:"last_access,omitempty"`
|
||||
LastOrigin string `json:"last_origin,omitempty"`
|
||||
Expires int64 `json:"expires,omitempty"` // Unix timestamp
|
||||
}
|
||||
|
||||
type apiAccountTier struct {
|
||||
|
|
|
@ -131,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
|||
bandwidthLimiter: nil, // Set in resetLimiters
|
||||
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||
}
|
||||
v.resetLimitersNoLock(messages, emails)
|
||||
v.resetLimitersNoLock(messages, emails, false)
|
||||
return v
|
||||
}
|
||||
|
||||
|
@ -254,6 +254,13 @@ func (v *visitor) User() *user.User {
|
|||
return v.user // May be nil
|
||||
}
|
||||
|
||||
// IP returns the visitor IP address
|
||||
func (v *visitor) IP() netip.Addr {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
return v.ip
|
||||
}
|
||||
|
||||
// Authenticated returns true if a user successfully authenticated
|
||||
func (v *visitor) Authenticated() bool {
|
||||
v.mu.Lock()
|
||||
|
@ -268,7 +275,7 @@ func (v *visitor) SetUser(u *user.User) {
|
|||
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||
v.user = u
|
||||
if shouldResetLimiters {
|
||||
v.resetLimitersNoLock(0, 0)
|
||||
v.resetLimitersNoLock(0, 0, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -283,7 +290,7 @@ func (v *visitor) MaybeUserID() string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (v *visitor) resetLimitersNoLock(messages, emails int64) {
|
||||
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) {
|
||||
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
|
||||
limits := v.limitsNoLock()
|
||||
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
||||
|
@ -295,6 +302,13 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64) {
|
|||
} else {
|
||||
v.accountLimiter = nil // Users cannot create accounts when logged in
|
||||
}
|
||||
/*
|
||||
if enqueueUpdate && v.user != nil {
|
||||
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{
|
||||
Messages: messages,
|
||||
Emails: emails,
|
||||
})
|
||||
}*/
|
||||
}
|
||||
|
||||
func (v *visitor) Limits() *visitorLimits {
|
||||
|
@ -361,7 +375,7 @@ func (v *visitor) Info() (*visitorInfo, error) {
|
|||
if u != nil {
|
||||
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID)
|
||||
} else {
|
||||
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
|
||||
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.IP().String())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue