Add "last access" to access tokens
This commit is contained in:
parent
000bf27c87
commit
e596834096
15 changed files with 276 additions and 145 deletions
|
@ -5,6 +5,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
"heckel.io/ntfy/client"
|
"heckel.io/ntfy/client"
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -13,7 +14,7 @@ import (
|
||||||
// This only contains helpers so far
|
// This only contains helpers so far
|
||||||
|
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
// log.SetOutput(io.Discard)
|
log.SetLevel(log.WarnLevel)
|
||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,6 +77,7 @@ type Config struct {
|
||||||
AuthStartupQueries string
|
AuthStartupQueries string
|
||||||
AuthDefault user.Permission
|
AuthDefault user.Permission
|
||||||
AuthBcryptCost int
|
AuthBcryptCost int
|
||||||
|
AuthStatsQueueWriterInterval time.Duration
|
||||||
AttachmentCacheDir string
|
AttachmentCacheDir string
|
||||||
AttachmentTotalSizeLimit int64
|
AttachmentTotalSizeLimit int64
|
||||||
AttachmentFileSizeLimit int64
|
AttachmentFileSizeLimit int64
|
||||||
|
@ -145,6 +146,7 @@ func NewConfig() *Config {
|
||||||
AuthStartupQueries: "",
|
AuthStartupQueries: "",
|
||||||
AuthDefault: user.NewPermission(true, true),
|
AuthDefault: user.NewPermission(true, true),
|
||||||
AuthBcryptCost: user.DefaultUserPasswordBcryptCost,
|
AuthBcryptCost: user.DefaultUserPasswordBcryptCost,
|
||||||
|
AuthStatsQueueWriterInterval: user.DefaultUserStatsQueueWriterInterval,
|
||||||
AttachmentCacheDir: "",
|
AttachmentCacheDir: "",
|
||||||
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
AttachmentTotalSizeLimit: DefaultAttachmentTotalSizeLimit,
|
||||||
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
|
AttachmentFileSizeLimit: DefaultAttachmentFileSizeLimit,
|
||||||
|
|
|
@ -171,7 +171,7 @@ func New(conf *Config) (*Server, error) {
|
||||||
}
|
}
|
||||||
var userManager *user.Manager
|
var userManager *user.Manager
|
||||||
if conf.AuthFile != "" {
|
if conf.AuthFile != "" {
|
||||||
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, user.DefaultUserStatsQueueWriterInterval)
|
userManager, err = user.NewManager(conf.AuthFile, conf.AuthStartupQueries, conf.AuthDefault, conf.AuthBcryptCost, conf.AuthStatsQueueWriterInterval)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -598,7 +598,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
}
|
}
|
||||||
u := v.User()
|
u := v.User()
|
||||||
if s.userManager != nil && u != nil && u.Tier != nil {
|
if s.userManager != nil && u != nil && u.Tier != nil {
|
||||||
s.userManager.EnqueueStats(u.ID, v.Stats())
|
go s.userManager.EnqueueStats(u.ID, v.Stats())
|
||||||
}
|
}
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
s.messages++
|
s.messages++
|
||||||
|
@ -1620,7 +1620,7 @@ func (s *Server) authenticate(r *http.Request) (user *user.User, err error) {
|
||||||
return nil, errHTTPUnauthorized
|
return nil, errHTTPUnauthorized
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(value, "Bearer") {
|
if strings.HasPrefix(value, "Bearer") {
|
||||||
return s.authenticateBearerAuth(value)
|
return s.authenticateBearerAuth(r, value)
|
||||||
}
|
}
|
||||||
return s.authenticateBasicAuth(r, value)
|
return s.authenticateBasicAuth(r, value)
|
||||||
}
|
}
|
||||||
|
@ -1634,9 +1634,18 @@ func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *use
|
||||||
return s.userManager.Authenticate(username, password)
|
return s.userManager.Authenticate(username, password)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) authenticateBearerAuth(value string) (user *user.User, err error) {
|
func (s *Server) authenticateBearerAuth(r *http.Request, value string) (*user.User, error) {
|
||||||
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
||||||
return s.userManager.AuthenticateToken(token)
|
u, err := s.userManager.AuthenticateToken(token)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ip := extractIPAddress(r, s.config.BehindProxy)
|
||||||
|
go s.userManager.EnqueueTokenUpdate(token, &user.TokenUpdate{
|
||||||
|
LastAccess: time.Now(),
|
||||||
|
LastOrigin: ip,
|
||||||
|
})
|
||||||
|
return u, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
func (s *Server) visitor(ip netip.Addr, user *user.User) *visitor {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -122,10 +123,16 @@ func (s *Server) handleAccountGet(w http.ResponseWriter, _ *http.Request, v *vis
|
||||||
if len(tokens) > 0 {
|
if len(tokens) > 0 {
|
||||||
response.Tokens = make([]*apiAccountTokenResponse, 0)
|
response.Tokens = make([]*apiAccountTokenResponse, 0)
|
||||||
for _, t := range tokens {
|
for _, t := range tokens {
|
||||||
|
var lastOrigin string
|
||||||
|
if t.LastOrigin != netip.IPv4Unspecified() {
|
||||||
|
lastOrigin = t.LastOrigin.String()
|
||||||
|
}
|
||||||
response.Tokens = append(response.Tokens, &apiAccountTokenResponse{
|
response.Tokens = append(response.Tokens, &apiAccountTokenResponse{
|
||||||
Token: t.Value,
|
Token: t.Value,
|
||||||
Label: t.Label,
|
Label: t.Label,
|
||||||
Expires: t.Expires.Unix(),
|
LastAccess: t.LastAccess.Unix(),
|
||||||
|
LastOrigin: lastOrigin,
|
||||||
|
Expires: t.Expires.Unix(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -192,14 +199,16 @@ func (s *Server) handleAccountTokenCreate(w http.ResponseWriter, r *http.Request
|
||||||
if req.Expires != nil {
|
if req.Expires != nil {
|
||||||
expires = time.Unix(*req.Expires, 0)
|
expires = time.Unix(*req.Expires, 0)
|
||||||
}
|
}
|
||||||
token, err := s.userManager.CreateToken(v.User().ID, label, expires)
|
token, err := s.userManager.CreateToken(v.User().ID, label, expires, v.IP())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
response := &apiAccountTokenResponse{
|
response := &apiAccountTokenResponse{
|
||||||
Token: token.Value,
|
Token: token.Value,
|
||||||
Label: token.Label,
|
Label: token.Label,
|
||||||
Expires: token.Expires.Unix(),
|
LastAccess: token.LastAccess.Unix(),
|
||||||
|
LastOrigin: token.LastOrigin.String(),
|
||||||
|
Expires: token.Expires.Unix(),
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, response)
|
return s.writeJSON(w, response)
|
||||||
}
|
}
|
||||||
|
@ -228,9 +237,11 @@ func (s *Server) handleAccountTokenUpdate(w http.ResponseWriter, r *http.Request
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
response := &apiAccountTokenResponse{
|
response := &apiAccountTokenResponse{
|
||||||
Token: token.Value,
|
Token: token.Value,
|
||||||
Label: token.Label,
|
Label: token.Label,
|
||||||
Expires: token.Expires.Unix(),
|
LastAccess: token.LastAccess.Unix(),
|
||||||
|
LastOrigin: token.LastOrigin.String(),
|
||||||
|
Expires: token.Expires.Unix(),
|
||||||
}
|
}
|
||||||
return s.writeJSON(w, response)
|
return s.writeJSON(w, response)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"heckel.io/ntfy/user"
|
"heckel.io/ntfy/user"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
@ -28,6 +29,10 @@ func TestAccount_Signup_Success(t *testing.T) {
|
||||||
token, _ := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body))
|
token, _ := util.UnmarshalJSON[apiAccountTokenResponse](io.NopCloser(rr.Body))
|
||||||
require.NotEmpty(t, token.Token)
|
require.NotEmpty(t, token.Token)
|
||||||
require.True(t, time.Now().Add(71*time.Hour).Unix() < token.Expires)
|
require.True(t, time.Now().Add(71*time.Hour).Unix() < token.Expires)
|
||||||
|
require.True(t, strings.HasPrefix(token.Token, "tk_"))
|
||||||
|
require.Equal(t, "9.9.9.9", token.LastOrigin)
|
||||||
|
require.True(t, token.LastAccess > time.Now().Unix()-1)
|
||||||
|
require.True(t, token.LastAccess < time.Now().Unix()+1)
|
||||||
|
|
||||||
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
rr = request(t, s, "GET", "/v1/account", "", map[string]string{
|
||||||
"Authorization": util.BearerAuth(token.Token),
|
"Authorization": util.BearerAuth(token.Token),
|
||||||
|
@ -161,7 +166,7 @@ func TestAccount_ChangeSettings(t *testing.T) {
|
||||||
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
u, _ := s.userManager.User("phil")
|
u, _ := s.userManager.User("phil")
|
||||||
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0))
|
token, _ := s.userManager.CreateToken(u.ID, "", time.Unix(0, 0), netip.IPv4Unspecified())
|
||||||
|
|
||||||
rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
|
rr := request(t, s, "PATCH", "/v1/account/settings", `{"notification": {"sound": "juntos"},"ignored": true}`, map[string]string{
|
||||||
"Authorization": util.BasicAuth("phil", "phil"),
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
@ -558,7 +563,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
// Subscribe anonymously
|
// Subscribe anonymously
|
||||||
anonCh, userCh := make(chan bool), make(chan bool)
|
anonCh, userCh := make(chan bool), make(chan bool)
|
||||||
go func() {
|
go func() {
|
||||||
rr := request(t, s, "GET", "/mytopic/json", ``, nil)
|
rr := request(t, s, "GET", "/mytopic/json", ``, nil) // This blocks until it's killed!
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
messages := toMessages(t, rr.Body.String())
|
messages := toMessages(t, rr.Body.String())
|
||||||
require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
|
require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
|
||||||
|
@ -570,7 +575,7 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
|
|
||||||
// Subscribe with user
|
// Subscribe with user
|
||||||
go func() {
|
go func() {
|
||||||
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{
|
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{ // Blocks!
|
||||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||||
})
|
})
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
@ -584,10 +589,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Publish message (before reservation)
|
// Publish message (before reservation)
|
||||||
time.Sleep(time.Second) // Wait for subscribers
|
time.Sleep(2 * time.Second) // Wait for subscribers
|
||||||
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
time.Sleep(time.Second) // Wait for subscribers to receive message
|
time.Sleep(2 * time.Second) // Wait for subscribers to receive message
|
||||||
|
|
||||||
// Reserve a topic
|
// Reserve a topic
|
||||||
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
||||||
|
@ -596,7 +601,11 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
|
|
||||||
// Everyone but phil should be killed
|
// Everyone but phil should be killed
|
||||||
<-anonCh
|
select {
|
||||||
|
case <-anonCh:
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("Waiting for anonymous subscription to be killed failed")
|
||||||
|
}
|
||||||
|
|
||||||
// Publish a message
|
// Publish a message
|
||||||
rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
|
rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
|
||||||
|
@ -606,62 +615,10 @@ func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||||
|
|
||||||
// Kill user Go routine
|
// Kill user Go routine
|
||||||
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
||||||
<-userCh
|
|
||||||
}
|
select {
|
||||||
|
case <-userCh:
|
||||||
func TestAccount_Tier_Create(t *testing.T) {
|
case <-time.After(5 * time.Second):
|
||||||
conf := newTestConfigWithAuthFile(t)
|
t.Fatal("Waiting for user subscription to be killed failed")
|
||||||
s := newTestServer(t, conf)
|
}
|
||||||
|
|
||||||
// Create tier and user
|
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
||||||
Code: "pro",
|
|
||||||
Name: "Pro",
|
|
||||||
MessageLimit: 123,
|
|
||||||
MessageExpiryDuration: 86400 * time.Second,
|
|
||||||
EmailLimit: 32,
|
|
||||||
ReservationLimit: 2,
|
|
||||||
AttachmentFileSizeLimit: 1231231,
|
|
||||||
AttachmentTotalSizeLimit: 123123,
|
|
||||||
AttachmentExpiryDuration: 10800 * time.Second,
|
|
||||||
AttachmentBandwidthLimit: 21474836480,
|
|
||||||
}))
|
|
||||||
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
|
||||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
|
||||||
|
|
||||||
ti, err := s.userManager.Tier("pro")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
u, err := s.userManager.User("phil")
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// These are populated by different SQL queries
|
|
||||||
require.Equal(t, ti, u.Tier)
|
|
||||||
|
|
||||||
// Fields
|
|
||||||
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
|
||||||
require.Equal(t, "pro", ti.Code)
|
|
||||||
require.Equal(t, "Pro", ti.Name)
|
|
||||||
require.Equal(t, int64(123), ti.MessageLimit)
|
|
||||||
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
|
||||||
require.Equal(t, int64(32), ti.EmailLimit)
|
|
||||||
require.Equal(t, int64(2), ti.ReservationLimit)
|
|
||||||
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
|
||||||
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
|
||||||
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
|
||||||
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
|
||||||
conf := newTestConfigWithAuthFile(t)
|
|
||||||
s := newTestServer(t, conf)
|
|
||||||
|
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
|
||||||
ID: "ti_123",
|
|
||||||
Code: "pro",
|
|
||||||
}))
|
|
||||||
|
|
||||||
ti, err := s.userManager.Tier("pro")
|
|
||||||
require.Nil(t, err)
|
|
||||||
require.Equal(t, "ti_123", ti.ID)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -258,11 +258,6 @@ func TestPayments_Checkout_Success_And_Increase_Rate_Limits_Reset_Visitor(t *tes
|
||||||
c.StripeWebhookKey = "webhook key"
|
c.StripeWebhookKey = "webhook key"
|
||||||
c.VisitorRequestLimitBurst = 5
|
c.VisitorRequestLimitBurst = 5
|
||||||
c.VisitorRequestLimitReplenish = time.Hour
|
c.VisitorRequestLimitReplenish = time.Hour
|
||||||
c.CacheStartupQueries = `
|
|
||||||
pragma journal_mode = WAL;
|
|
||||||
pragma synchronous = normal;
|
|
||||||
pragma temp_store = memory;
|
|
||||||
`
|
|
||||||
c.CacheBatchSize = 500
|
c.CacheBatchSize = 500
|
||||||
c.CacheBatchTimeout = time.Second
|
c.CacheBatchTimeout = time.Second
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
|
@ -324,6 +319,18 @@ pragma temp_store = memory;
|
||||||
})
|
})
|
||||||
require.Equal(t, 429, rr.Code)
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
// Verify some "before-stats"
|
||||||
|
u, err = s.userManager.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, u.Tier)
|
||||||
|
require.Equal(t, "", u.Billing.StripeCustomerID)
|
||||||
|
require.Equal(t, "", u.Billing.StripeSubscriptionID)
|
||||||
|
require.Equal(t, stripe.SubscriptionStatus(""), u.Billing.StripeSubscriptionStatus)
|
||||||
|
require.Equal(t, int64(0), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||||
|
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||||
|
require.Equal(t, int64(0), u.Stats.Messages) // Messages and emails are not persisted for no-tier users!
|
||||||
|
require.Equal(t, int64(0), u.Stats.Emails)
|
||||||
|
|
||||||
// Simulate Stripe success return URL call (no user context)
|
// Simulate Stripe success return URL call (no user context)
|
||||||
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
|
rr = request(t, s, "GET", "/v1/account/billing/subscription/success/SOMETOKEN", "", nil)
|
||||||
require.Equal(t, 303, rr.Code)
|
require.Equal(t, 303, rr.Code)
|
||||||
|
@ -337,6 +344,8 @@ pragma temp_store = memory;
|
||||||
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
require.Equal(t, stripe.SubscriptionStatusActive, u.Billing.StripeSubscriptionStatus)
|
||||||
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
require.Equal(t, int64(123456789), u.Billing.StripeSubscriptionPaidUntil.Unix())
|
||||||
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
require.Equal(t, int64(0), u.Billing.StripeSubscriptionCancelAt.Unix())
|
||||||
|
require.Equal(t, int64(0), u.Stats.Messages)
|
||||||
|
require.Equal(t, int64(0), u.Stats.Emails)
|
||||||
|
|
||||||
// Now for the fun part: Verify that new rate limits are immediately applied
|
// Now for the fun part: Verify that new rate limits are immediately applied
|
||||||
// This only tests the request limiter, which kicks in before the message limiter.
|
// This only tests the request limiter, which kicks in before the message limiter.
|
||||||
|
|
|
@ -892,10 +892,8 @@ func TestServer_DailyMessageQuotaFromDatabase(t *testing.T) {
|
||||||
// if the visitor is unknown
|
// if the visitor is unknown
|
||||||
|
|
||||||
c := newTestConfigWithAuthFile(t)
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthStatsQueueWriterInterval = 100 * time.Millisecond
|
||||||
s := newTestServer(t, c)
|
s := newTestServer(t, c)
|
||||||
var err error
|
|
||||||
s.userManager, err = user.NewManager(c.AuthFile, c.AuthStartupQueries, c.AuthDefault, c.AuthBcryptCost, 100*time.Millisecond)
|
|
||||||
require.Nil(t, err)
|
|
||||||
|
|
||||||
// Create user, and update it with some message and email stats
|
// Create user, and update it with some message and email stats
|
||||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||||
|
|
|
@ -247,9 +247,11 @@ type apiAccountTokenUpdateRequest struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type apiAccountTokenResponse struct {
|
type apiAccountTokenResponse struct {
|
||||||
Token string `json:"token"`
|
Token string `json:"token"`
|
||||||
Label string `json:"label,omitempty"`
|
Label string `json:"label,omitempty"`
|
||||||
Expires int64 `json:"expires,omitempty"` // Unix timestamp
|
LastAccess int64 `json:"last_access,omitempty"`
|
||||||
|
LastOrigin string `json:"last_origin,omitempty"`
|
||||||
|
Expires int64 `json:"expires,omitempty"` // Unix timestamp
|
||||||
}
|
}
|
||||||
|
|
||||||
type apiAccountTier struct {
|
type apiAccountTier struct {
|
||||||
|
|
|
@ -131,7 +131,7 @@ func newVisitor(conf *Config, messageCache *messageCache, userManager *user.Mana
|
||||||
bandwidthLimiter: nil, // Set in resetLimiters
|
bandwidthLimiter: nil, // Set in resetLimiters
|
||||||
accountLimiter: nil, // Set in resetLimiters, may be nil
|
accountLimiter: nil, // Set in resetLimiters, may be nil
|
||||||
}
|
}
|
||||||
v.resetLimitersNoLock(messages, emails)
|
v.resetLimitersNoLock(messages, emails, false)
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -254,6 +254,13 @@ func (v *visitor) User() *user.User {
|
||||||
return v.user // May be nil
|
return v.user // May be nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IP returns the visitor IP address
|
||||||
|
func (v *visitor) IP() netip.Addr {
|
||||||
|
v.mu.Lock()
|
||||||
|
defer v.mu.Unlock()
|
||||||
|
return v.ip
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticated returns true if a user successfully authenticated
|
// Authenticated returns true if a user successfully authenticated
|
||||||
func (v *visitor) Authenticated() bool {
|
func (v *visitor) Authenticated() bool {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
|
@ -268,7 +275,7 @@ func (v *visitor) SetUser(u *user.User) {
|
||||||
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
shouldResetLimiters := v.user.TierID() != u.TierID() // TierID works with nil receiver
|
||||||
v.user = u
|
v.user = u
|
||||||
if shouldResetLimiters {
|
if shouldResetLimiters {
|
||||||
v.resetLimitersNoLock(0, 0)
|
v.resetLimitersNoLock(0, 0, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -283,7 +290,7 @@ func (v *visitor) MaybeUserID() string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) resetLimitersNoLock(messages, emails int64) {
|
func (v *visitor) resetLimitersNoLock(messages, emails int64, enqueueUpdate bool) {
|
||||||
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
|
log.Debug("%s Resetting limiters for visitor", v.stringNoLock())
|
||||||
limits := v.limitsNoLock()
|
limits := v.limitsNoLock()
|
||||||
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
v.requestLimiter = rate.NewLimiter(limits.RequestLimitReplenish, limits.RequestLimitBurst)
|
||||||
|
@ -295,6 +302,13 @@ func (v *visitor) resetLimitersNoLock(messages, emails int64) {
|
||||||
} else {
|
} else {
|
||||||
v.accountLimiter = nil // Users cannot create accounts when logged in
|
v.accountLimiter = nil // Users cannot create accounts when logged in
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
|
if enqueueUpdate && v.user != nil {
|
||||||
|
go v.userManager.EnqueueStats(v.user.ID, &user.Stats{
|
||||||
|
Messages: messages,
|
||||||
|
Emails: emails,
|
||||||
|
})
|
||||||
|
}*/
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) Limits() *visitorLimits {
|
func (v *visitor) Limits() *visitorLimits {
|
||||||
|
@ -361,7 +375,7 @@ func (v *visitor) Info() (*visitorInfo, error) {
|
||||||
if u != nil {
|
if u != nil {
|
||||||
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID)
|
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedByUser(u.ID)
|
||||||
} else {
|
} else {
|
||||||
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.ip.String())
|
attachmentsBytesUsed, err = v.messageCache.AttachmentBytesUsedBySender(v.IP().String())
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
111
user/manager.go
111
user/manager.go
|
@ -10,6 +10,7 @@ import (
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -95,6 +96,8 @@ const (
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
token TEXT NOT NULL,
|
token TEXT NOT NULL,
|
||||||
label TEXT NOT NULL,
|
label TEXT NOT NULL,
|
||||||
|
last_access INT NOT NULL,
|
||||||
|
last_origin TEXT NOT NULL,
|
||||||
expires INT NOT NULL,
|
expires INT NOT NULL,
|
||||||
PRIMARY KEY (user_id, token),
|
PRIMARY KEY (user_id, token),
|
||||||
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
|
||||||
|
@ -127,9 +130,9 @@ const (
|
||||||
selectUserByTokenQuery = `
|
selectUserByTokenQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
FROM user u
|
FROM user u
|
||||||
JOIN user_token t on u.id = t.user_id
|
JOIN user_token tk on u.id = tk.user_id
|
||||||
LEFT JOIN tier t on t.id = u.tier_id
|
LEFT JOIN tier t on t.id = u.tier_id
|
||||||
WHERE t.token = ? AND (t.expires = 0 OR t.expires >= ?)
|
WHERE tk.token = ? AND (tk.expires = 0 OR tk.expires >= ?)
|
||||||
`
|
`
|
||||||
selectUserByStripeCustomerIDQuery = `
|
selectUserByStripeCustomerIDQuery = `
|
||||||
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
SELECT u.id, u.user, u.pass, u.role, u.prefs, u.sync_topic, u.stats_messages, u.stats_emails, u.stripe_customer_id, u.stripe_subscription_id, u.stripe_subscription_status, u.stripe_subscription_paid_until, u.stripe_subscription_cancel_at, deleted, t.id, t.code, t.name, t.messages_limit, t.messages_expiry_duration, t.emails_limit, t.reservations_limit, t.attachment_file_size_limit, t.attachment_total_size_limit, t.attachment_expiry_duration, t.attachment_bandwidth_limit, t.stripe_price_id
|
||||||
|
@ -218,16 +221,17 @@ const (
|
||||||
AND topic = ?
|
AND topic = ?
|
||||||
`
|
`
|
||||||
|
|
||||||
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
selectTokenCountQuery = `SELECT COUNT(*) FROM user_token WHERE user_id = ?`
|
||||||
selectTokensQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ?`
|
selectTokensQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ?`
|
||||||
selectTokenQuery = `SELECT token, label, expires FROM user_token WHERE user_id = ? AND token = ?`
|
selectTokenQuery = `SELECT token, label, last_access, last_origin, expires FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, expires) VALUES (?, ?, ?, ?)`
|
insertTokenQuery = `INSERT INTO user_token (user_id, token, label, last_access, last_origin, expires) VALUES (?, ?, ?, ?, ?, ?)`
|
||||||
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
updateTokenExpiryQuery = `UPDATE user_token SET expires = ? WHERE user_id = ? AND token = ?`
|
||||||
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
updateTokenLabelQuery = `UPDATE user_token SET label = ? WHERE user_id = ? AND token = ?`
|
||||||
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
updateTokenLastAccessQuery = `UPDATE user_token SET last_access = ?, last_origin = ? WHERE token = ?`
|
||||||
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
deleteTokenQuery = `DELETE FROM user_token WHERE user_id = ? AND token = ?`
|
||||||
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
deleteAllTokenQuery = `DELETE FROM user_token WHERE user_id = ?`
|
||||||
deleteExcessTokensQuery = `
|
deleteExpiredTokensQuery = `DELETE FROM user_token WHERE expires > 0 AND expires < ?`
|
||||||
|
deleteExcessTokensQuery = `
|
||||||
DELETE FROM user_token
|
DELETE FROM user_token
|
||||||
WHERE (user_id, token) NOT IN (
|
WHERE (user_id, token) NOT IN (
|
||||||
SELECT user_id, token
|
SELECT user_id, token
|
||||||
|
@ -297,16 +301,17 @@ const (
|
||||||
// in a SQLite database.
|
// in a SQLite database.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
defaultAccess Permission // Default permission if no ACL matches
|
defaultAccess Permission // Default permission if no ACL matches
|
||||||
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
statsQueue map[string]*Stats // "Queue" to asynchronously write user stats to the database (UserID -> Stats)
|
||||||
bcryptCost int // Makes testing easier
|
tokenQueue map[string]*TokenUpdate // "Queue" to asynchronously write token access stats to the database (Token ID -> TokenUpdate)
|
||||||
|
bcryptCost int // Makes testing easier
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ Auther = (*Manager)(nil)
|
var _ Auther = (*Manager)(nil)
|
||||||
|
|
||||||
// NewManager creates a new Manager instance
|
// NewManager creates a new Manager instance
|
||||||
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, statsWriterInterval time.Duration) (*Manager, error) {
|
func NewManager(filename, startupQueries string, defaultAccess Permission, bcryptCost int, queueWriterInterval time.Duration) (*Manager, error) {
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -321,9 +326,10 @@ func NewManager(filename, startupQueries string, defaultAccess Permission, bcryp
|
||||||
db: db,
|
db: db,
|
||||||
defaultAccess: defaultAccess,
|
defaultAccess: defaultAccess,
|
||||||
statsQueue: make(map[string]*Stats),
|
statsQueue: make(map[string]*Stats),
|
||||||
|
tokenQueue: make(map[string]*TokenUpdate),
|
||||||
bcryptCost: bcryptCost,
|
bcryptCost: bcryptCost,
|
||||||
}
|
}
|
||||||
go manager.userStatsQueueWriter(statsWriterInterval)
|
go manager.asyncQueueWriter(queueWriterInterval)
|
||||||
return manager, nil
|
return manager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -367,14 +373,15 @@ func (a *Manager) AuthenticateToken(token string) (*User, error) {
|
||||||
// CreateToken generates a random token for the given user and returns it. The token expires
|
// CreateToken generates a random token for the given user and returns it. The token expires
|
||||||
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
// after a fixed duration unless ChangeToken is called. This function also prunes tokens for the
|
||||||
// given user, if there are too many of them.
|
// given user, if there are too many of them.
|
||||||
func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token, error) {
|
func (a *Manager) CreateToken(userID, label string, expires time.Time, origin netip.Addr) (*Token, error) {
|
||||||
token := util.RandomStringPrefix(tokenPrefix, tokenLength)
|
token := util.RandomStringPrefix(tokenPrefix, tokenLength)
|
||||||
tx, err := a.db.Begin()
|
tx, err := a.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer tx.Rollback()
|
||||||
if _, err := tx.Exec(insertTokenQuery, userID, token, label, expires.Unix()); err != nil {
|
access := time.Now()
|
||||||
|
if _, err := tx.Exec(insertTokenQuery, userID, token, label, access.Unix(), origin.String(), expires.Unix()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
rows, err := tx.Query(selectTokenCountQuery, userID)
|
rows, err := tx.Query(selectTokenCountQuery, userID)
|
||||||
|
@ -400,9 +407,11 @@ func (a *Manager) CreateToken(userID, label string, expires time.Time) (*Token,
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &Token{
|
return &Token{
|
||||||
Value: token,
|
Value: token,
|
||||||
Label: label,
|
Label: label,
|
||||||
Expires: expires,
|
LastAccess: access,
|
||||||
|
LastOrigin: origin,
|
||||||
|
Expires: expires,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -437,20 +446,26 @@ func (a *Manager) Token(userID, token string) (*Token, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
func (a *Manager) readToken(rows *sql.Rows) (*Token, error) {
|
||||||
var token, label string
|
var token, label, lastOrigin string
|
||||||
var expires int64
|
var lastAccess, expires int64
|
||||||
if !rows.Next() {
|
if !rows.Next() {
|
||||||
return nil, ErrTokenNotFound
|
return nil, ErrTokenNotFound
|
||||||
}
|
}
|
||||||
if err := rows.Scan(&token, &label, &expires); err != nil {
|
if err := rows.Scan(&token, &label, &lastAccess, &lastOrigin, &expires); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if err := rows.Err(); err != nil {
|
} else if err := rows.Err(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
lastOriginIP, err := netip.ParseAddr(lastOrigin)
|
||||||
|
if err != nil {
|
||||||
|
lastOriginIP = netip.IPv4Unspecified()
|
||||||
|
}
|
||||||
return &Token{
|
return &Token{
|
||||||
Value: token,
|
Value: token,
|
||||||
Label: label,
|
Label: label,
|
||||||
Expires: time.Unix(expires, 0),
|
LastAccess: time.Unix(lastAccess, 0),
|
||||||
|
LastOrigin: lastOriginIP,
|
||||||
|
Expires: time.Unix(expires, 0),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -521,7 +536,7 @@ func (a *Manager) ChangeSettings(user *User) error {
|
||||||
|
|
||||||
// ResetStats resets all user stats in the user database. This touches all users.
|
// ResetStats resets all user stats in the user database. This touches all users.
|
||||||
func (a *Manager) ResetStats() error {
|
func (a *Manager) ResetStats() error {
|
||||||
a.mu.Lock()
|
a.mu.Lock() // Includes database query to avoid races!
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
|
if _, err := a.db.Exec(updateUserStatsResetAllQuery); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -538,12 +553,23 @@ func (a *Manager) EnqueueStats(userID string, stats *Stats) {
|
||||||
a.statsQueue[userID] = stats
|
a.statsQueue[userID] = stats
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Manager) userStatsQueueWriter(interval time.Duration) {
|
// EnqueueTokenUpdate adds the token update to a queue which writes out token access times
|
||||||
|
// in batches at a regular interval
|
||||||
|
func (a *Manager) EnqueueTokenUpdate(tokenID string, update *TokenUpdate) {
|
||||||
|
a.mu.Lock()
|
||||||
|
defer a.mu.Unlock()
|
||||||
|
a.tokenQueue[tokenID] = update
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Manager) asyncQueueWriter(interval time.Duration) {
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
if err := a.writeUserStatsQueue(); err != nil {
|
if err := a.writeUserStatsQueue(); err != nil {
|
||||||
log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
|
log.Warn("User Manager: Writing user stats queue failed: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
if err := a.writeTokenUpdateQueue(); err != nil {
|
||||||
|
log.Warn("User Manager: Writing token update queue failed: %s", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -572,6 +598,31 @@ func (a *Manager) writeUserStatsQueue() error {
|
||||||
return tx.Commit()
|
return tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Manager) writeTokenUpdateQueue() error {
|
||||||
|
a.mu.Lock()
|
||||||
|
if len(a.tokenQueue) == 0 {
|
||||||
|
a.mu.Unlock()
|
||||||
|
log.Trace("User Manager: No token updates to commit")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
tokenQueue := a.tokenQueue
|
||||||
|
a.tokenQueue = make(map[string]*TokenUpdate)
|
||||||
|
a.mu.Unlock()
|
||||||
|
tx, err := a.db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
log.Debug("User Manager: Writing token update queue for %d token(s)", len(tokenQueue))
|
||||||
|
for tokenID, update := range tokenQueue {
|
||||||
|
log.Trace("User Manager: Updating token %s with last access time %v", tokenID, update.LastAccess.Unix())
|
||||||
|
if _, err := tx.Exec(updateTokenLastAccessQuery, update.LastAccess.Unix(), update.LastOrigin.String(), tokenID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||||
// permission. The user param may be nil to signal an anonymous user.
|
// permission. The user param may be nil to signal an anonymous user.
|
||||||
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
func (a *Manager) Authorize(user *User, topic string, perm Permission) error {
|
||||||
|
|
|
@ -5,6 +5,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
|
"net/netip"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
@ -139,7 +140,7 @@ func TestManager_MarkUserRemoved_RemoveDeletedUsers(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.False(t, u.Deleted)
|
require.False(t, u.Deleted)
|
||||||
|
|
||||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour))
|
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
u, err = a.Authenticate("user", "pass")
|
u, err = a.Authenticate("user", "pass")
|
||||||
|
@ -397,7 +398,7 @@ func TestManager_Token_Valid(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Create token for user
|
// Create token for user
|
||||||
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour))
|
token, err := a.CreateToken(u.ID, "some label", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotEmpty(t, token.Value)
|
require.NotEmpty(t, token.Value)
|
||||||
require.Equal(t, "some label", token.Label)
|
require.Equal(t, "some label", token.Label)
|
||||||
|
@ -441,12 +442,12 @@ func TestManager_Token_Expire(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
||||||
// Create tokens for user
|
// Create tokens for user
|
||||||
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
token1, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotEmpty(t, token1.Value)
|
require.NotEmpty(t, token1.Value)
|
||||||
require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
|
require.True(t, time.Now().Add(71*time.Hour).Unix() < token1.Expires.Unix())
|
||||||
|
|
||||||
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
token2, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotEmpty(t, token2.Value)
|
require.NotEmpty(t, token2.Value)
|
||||||
require.NotEqual(t, token1.Value, token2.Value)
|
require.NotEqual(t, token1.Value, token2.Value)
|
||||||
|
@ -493,7 +494,7 @@ func TestManager_Token_Extend(t *testing.T) {
|
||||||
require.Equal(t, errNoTokenProvided, err)
|
require.Equal(t, errNoTokenProvided, err)
|
||||||
|
|
||||||
// Create token for user
|
// Create token for user
|
||||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotEmpty(t, token.Value)
|
require.NotEmpty(t, token.Value)
|
||||||
|
|
||||||
|
@ -520,7 +521,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||||
baseTime := time.Now().Add(24 * time.Hour)
|
baseTime := time.Now().Add(24 * time.Hour)
|
||||||
tokens := make([]string, 0)
|
tokens := make([]string, 0)
|
||||||
for i := 0; i < 12; i++ {
|
for i := 0; i < 12; i++ {
|
||||||
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour))
|
token, err := a.CreateToken(u.ID, "", time.Now().Add(72*time.Hour), netip.IPv4Unspecified())
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.NotEmpty(t, token.Value)
|
require.NotEmpty(t, token.Value)
|
||||||
tokens = append(tokens, token.Value)
|
tokens = append(tokens, token.Value)
|
||||||
|
@ -624,6 +625,61 @@ func TestManager_ChangeSettings(t *testing.T) {
|
||||||
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
require.Equal(t, util.String("My Topic"), u.Prefs.Subscriptions[0].DisplayName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_Tier_Create(t *testing.T) {
|
||||||
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
|
|
||||||
|
// Create tier and user
|
||||||
|
require.Nil(t, a.CreateTier(&Tier{
|
||||||
|
Code: "pro",
|
||||||
|
Name: "Pro",
|
||||||
|
MessageLimit: 123,
|
||||||
|
MessageExpiryDuration: 86400 * time.Second,
|
||||||
|
EmailLimit: 32,
|
||||||
|
ReservationLimit: 2,
|
||||||
|
AttachmentFileSizeLimit: 1231231,
|
||||||
|
AttachmentTotalSizeLimit: 123123,
|
||||||
|
AttachmentExpiryDuration: 10800 * time.Second,
|
||||||
|
AttachmentBandwidthLimit: 21474836480,
|
||||||
|
}))
|
||||||
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser))
|
||||||
|
require.Nil(t, a.ChangeTier("phil", "pro"))
|
||||||
|
|
||||||
|
ti, err := a.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
u, err := a.User("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// These are populated by different SQL queries
|
||||||
|
require.Equal(t, ti, u.Tier)
|
||||||
|
|
||||||
|
// Fields
|
||||||
|
require.True(t, strings.HasPrefix(ti.ID, "ti_"))
|
||||||
|
require.Equal(t, "pro", ti.Code)
|
||||||
|
require.Equal(t, "Pro", ti.Name)
|
||||||
|
require.Equal(t, int64(123), ti.MessageLimit)
|
||||||
|
require.Equal(t, 86400*time.Second, ti.MessageExpiryDuration)
|
||||||
|
require.Equal(t, int64(32), ti.EmailLimit)
|
||||||
|
require.Equal(t, int64(2), ti.ReservationLimit)
|
||||||
|
require.Equal(t, int64(1231231), ti.AttachmentFileSizeLimit)
|
||||||
|
require.Equal(t, int64(123123), ti.AttachmentTotalSizeLimit)
|
||||||
|
require.Equal(t, 10800*time.Second, ti.AttachmentExpiryDuration)
|
||||||
|
require.Equal(t, int64(21474836480), ti.AttachmentBandwidthLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAccount_Tier_Create_With_ID(t *testing.T) {
|
||||||
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
|
|
||||||
|
require.Nil(t, a.CreateTier(&Tier{
|
||||||
|
ID: "ti_123",
|
||||||
|
Code: "pro",
|
||||||
|
}))
|
||||||
|
|
||||||
|
ti, err := a.Tier("pro")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "ti_123", ti.ID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestSqliteCache_Migration_From1(t *testing.T) {
|
func TestSqliteCache_Migration_From1(t *testing.T) {
|
||||||
filename := filepath.Join(t.TempDir(), "user.db")
|
filename := filepath.Join(t.TempDir(), "user.db")
|
||||||
db, err := sql.Open("sqlite3", filename)
|
db, err := sql.Open("sqlite3", filename)
|
||||||
|
|
|
@ -4,6 +4,7 @@ package user
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/stripe/stripe-go/v74"
|
"github.com/stripe/stripe-go/v74"
|
||||||
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -46,9 +47,17 @@ type Auther interface {
|
||||||
|
|
||||||
// Token represents a user token, including expiry date
|
// Token represents a user token, including expiry date
|
||||||
type Token struct {
|
type Token struct {
|
||||||
Value string
|
Value string
|
||||||
Label string
|
Label string
|
||||||
Expires time.Time
|
LastAccess time.Time
|
||||||
|
LastOrigin netip.Addr
|
||||||
|
Expires time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenUpdate holds information about the last access time and origin IP address of a token
|
||||||
|
type TokenUpdate struct {
|
||||||
|
LastAccess time.Time
|
||||||
|
LastOrigin netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prefs represents a user's configuration settings
|
// Prefs represents a user's configuration settings
|
||||||
|
|
|
@ -226,6 +226,7 @@
|
||||||
"account_tokens_description": "Use access tokens when publishing and subscribing via the ntfy API, so you don't have to send your account credentials. Check out the <Link>documentation</Link> to learn more.",
|
"account_tokens_description": "Use access tokens when publishing and subscribing via the ntfy API, so you don't have to send your account credentials. Check out the <Link>documentation</Link> to learn more.",
|
||||||
"account_tokens_table_token_header": "Token",
|
"account_tokens_table_token_header": "Token",
|
||||||
"account_tokens_table_label_header": "Label",
|
"account_tokens_table_label_header": "Label",
|
||||||
|
"account_tokens_table_last_access_header": "Last access",
|
||||||
"account_tokens_table_expires_header": "Expires",
|
"account_tokens_table_expires_header": "Expires",
|
||||||
"account_tokens_table_never_expires": "Never expires",
|
"account_tokens_table_never_expires": "Never expires",
|
||||||
"account_tokens_table_current_session": "Current browser session",
|
"account_tokens_table_current_session": "Current browser session",
|
||||||
|
@ -233,6 +234,7 @@
|
||||||
"account_tokens_table_copied_to_clipboard": "Access token copied",
|
"account_tokens_table_copied_to_clipboard": "Access token copied",
|
||||||
"account_tokens_table_cannot_delete_or_edit": "Cannot edit or delete current session token",
|
"account_tokens_table_cannot_delete_or_edit": "Cannot edit or delete current session token",
|
||||||
"account_tokens_table_create_token_button": "Create access token",
|
"account_tokens_table_create_token_button": "Create access token",
|
||||||
|
"account_tokens_table_last_origin_tooltip": "From IP address {{ip}}, click to lookup",
|
||||||
"account_tokens_dialog_title_create": "Create access token",
|
"account_tokens_dialog_title_create": "Create access token",
|
||||||
"account_tokens_dialog_title_edit": "Edit access token",
|
"account_tokens_dialog_title_edit": "Edit access token",
|
||||||
"account_tokens_dialog_title_delete": "Delete access token",
|
"account_tokens_dialog_title_delete": "Delete access token",
|
||||||
|
|
|
@ -27,7 +27,7 @@ import DialogContent from "@mui/material/DialogContent";
|
||||||
import TextField from "@mui/material/TextField";
|
import TextField from "@mui/material/TextField";
|
||||||
import routes from "./routes";
|
import routes from "./routes";
|
||||||
import IconButton from "@mui/material/IconButton";
|
import IconButton from "@mui/material/IconButton";
|
||||||
import {formatBytes, formatShortDate, formatShortDateTime, truncateString, validUrl} from "../app/utils";
|
import {formatBytes, formatShortDate, formatShortDateTime, openUrl, truncateString, validUrl} from "../app/utils";
|
||||||
import accountApi, {IncorrectPasswordError, UnauthorizedError} from "../app/AccountApi";
|
import accountApi, {IncorrectPasswordError, UnauthorizedError} from "../app/AccountApi";
|
||||||
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
|
import InfoOutlinedIcon from '@mui/icons-material/InfoOutlined';
|
||||||
import {Pref, PrefGroup} from "./Pref";
|
import {Pref, PrefGroup} from "./Pref";
|
||||||
|
@ -43,7 +43,7 @@ import userManager from "../app/UserManager";
|
||||||
import {Paragraph} from "./styles";
|
import {Paragraph} from "./styles";
|
||||||
import CloseIcon from "@mui/icons-material/Close";
|
import CloseIcon from "@mui/icons-material/Close";
|
||||||
import DialogActions from "@mui/material/DialogActions";
|
import DialogActions from "@mui/material/DialogActions";
|
||||||
import {ContentCopy} from "@mui/icons-material";
|
import {ContentCopy, Public} from "@mui/icons-material";
|
||||||
import MenuItem from "@mui/material/MenuItem";
|
import MenuItem from "@mui/material/MenuItem";
|
||||||
import ListItemIcon from "@mui/material/ListItemIcon";
|
import ListItemIcon from "@mui/material/ListItemIcon";
|
||||||
import {PermissionDenyAll, PermissionRead, PermissionReadWrite, PermissionWrite} from "./ReserveIcons";
|
import {PermissionDenyAll, PermissionRead, PermissionReadWrite, PermissionWrite} from "./ReserveIcons";
|
||||||
|
@ -506,6 +506,7 @@ const TokensTable = (props) => {
|
||||||
<TableCell sx={{paddingLeft: 0}}>{t("account_tokens_table_token_header")}</TableCell>
|
<TableCell sx={{paddingLeft: 0}}>{t("account_tokens_table_token_header")}</TableCell>
|
||||||
<TableCell>{t("account_tokens_table_label_header")}</TableCell>
|
<TableCell>{t("account_tokens_table_label_header")}</TableCell>
|
||||||
<TableCell>{t("account_tokens_table_expires_header")}</TableCell>
|
<TableCell>{t("account_tokens_table_expires_header")}</TableCell>
|
||||||
|
<TableCell>{t("account_tokens_table_last_access_header")}</TableCell>
|
||||||
<TableCell/>
|
<TableCell/>
|
||||||
</TableRow>
|
</TableRow>
|
||||||
</TableHead>
|
</TableHead>
|
||||||
|
@ -513,11 +514,11 @@ const TokensTable = (props) => {
|
||||||
{tokens.map(token => (
|
{tokens.map(token => (
|
||||||
<TableRow
|
<TableRow
|
||||||
key={token.token}
|
key={token.token}
|
||||||
sx={{'&:last-child td, &:last-child th': {border: 0}}}
|
sx={{'&:last-child td, &:last-child th': { border: 0 }}}
|
||||||
>
|
>
|
||||||
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("account_tokens_table_token_header")}>
|
<TableCell component="th" scope="row" sx={{ paddingLeft: 0, whiteSpace: "nowrap" }} aria-label={t("account_tokens_table_token_header")}>
|
||||||
<span>
|
<span>
|
||||||
<span style={{fontFamily: "Monospace", fontSize: "0.9rem"}}>{token.token.slice(0, 20)}</span>
|
<span style={{fontFamily: "Monospace", fontSize: "0.9rem"}}>{token.token.slice(0, 12)}</span>
|
||||||
...
|
...
|
||||||
<Tooltip title={t("account_tokens_table_copy_to_clipboard")} placement="right">
|
<Tooltip title={t("account_tokens_table_copy_to_clipboard")} placement="right">
|
||||||
<IconButton onClick={() => handleCopy(token.token)}><ContentCopy/></IconButton>
|
<IconButton onClick={() => handleCopy(token.token)}><ContentCopy/></IconButton>
|
||||||
|
@ -531,7 +532,17 @@ const TokensTable = (props) => {
|
||||||
<TableCell aria-label={t("account_tokens_table_expires_header")}>
|
<TableCell aria-label={t("account_tokens_table_expires_header")}>
|
||||||
{token.expires ? formatShortDateTime(token.expires) : <em>{t("account_tokens_table_never_expires")}</em>}
|
{token.expires ? formatShortDateTime(token.expires) : <em>{t("account_tokens_table_never_expires")}</em>}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell align="right">
|
<TableCell aria-label={t("account_tokens_table_last_access_header")}>
|
||||||
|
<div style={{ display: "flex", alignItems: "center" }}>
|
||||||
|
<span>{formatShortDateTime(token.last_access)}</span>
|
||||||
|
<Tooltip title={t("account_tokens_table_last_origin_tooltip", { ip: token.last_origin })}>
|
||||||
|
<IconButton onClick={() => openUrl(`https://whatismyipaddress.com/ip/${token.last_origin}`)}>
|
||||||
|
<Public />
|
||||||
|
</IconButton>
|
||||||
|
</Tooltip>
|
||||||
|
</div>
|
||||||
|
</TableCell>
|
||||||
|
<TableCell align="right" sx={{ whiteSpace: "nowrap" }}>
|
||||||
{token.token !== session.token() &&
|
{token.token !== session.token() &&
|
||||||
<>
|
<>
|
||||||
<IconButton onClick={() => handleEditClick(token)} aria-label={t("account_tokens_dialog_title_edit")}>
|
<IconButton onClick={() => handleEditClick(token)} aria-label={t("account_tokens_dialog_title_edit")}>
|
||||||
|
|
|
@ -300,10 +300,9 @@ const UserTable = (props) => {
|
||||||
key={user.baseUrl}
|
key={user.baseUrl}
|
||||||
sx={{'&:last-child td, &:last-child th': {border: 0}}}
|
sx={{'&:last-child td, &:last-child th': {border: 0}}}
|
||||||
>
|
>
|
||||||
<TableCell component="th" scope="row" sx={{paddingLeft: 0}}
|
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_users_table_user_header")}>{user.username}</TableCell>
|
||||||
aria-label={t("prefs_users_table_user_header")}>{user.username}</TableCell>
|
|
||||||
<TableCell aria-label={t("prefs_users_table_base_url_header")}>{user.baseUrl}</TableCell>
|
<TableCell aria-label={t("prefs_users_table_base_url_header")}>{user.baseUrl}</TableCell>
|
||||||
<TableCell align="right">
|
<TableCell align="right" sx={{ whiteSpace: "nowrap" }}>
|
||||||
{(!session.exists() || user.baseUrl !== config.base_url) &&
|
{(!session.exists() || user.baseUrl !== config.base_url) &&
|
||||||
<>
|
<>
|
||||||
<IconButton onClick={() => handleEditClick(user)} aria-label={t("prefs_users_edit_button")}>
|
<IconButton onClick={() => handleEditClick(user)} aria-label={t("prefs_users_edit_button")}>
|
||||||
|
@ -597,7 +596,7 @@ const ReservationsTable = (props) => {
|
||||||
{props.reservations.map(reservation => (
|
{props.reservations.map(reservation => (
|
||||||
<TableRow
|
<TableRow
|
||||||
key={reservation.topic}
|
key={reservation.topic}
|
||||||
sx={{'&:last-child td, &:last-child th': {border: 0}}}
|
sx={{'&:last-child td, &:last-child th': { border: 0 }}}
|
||||||
>
|
>
|
||||||
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>
|
<TableCell component="th" scope="row" sx={{paddingLeft: 0}} aria-label={t("prefs_reservations_table_topic_header")}>
|
||||||
{reservation.topic}
|
{reservation.topic}
|
||||||
|
@ -628,7 +627,7 @@ const ReservationsTable = (props) => {
|
||||||
</>
|
</>
|
||||||
}
|
}
|
||||||
</TableCell>
|
</TableCell>
|
||||||
<TableCell align="right">
|
<TableCell align="right" sx={{ whiteSpace: "nowrap" }}>
|
||||||
{!localSubscriptions[reservation.topic] &&
|
{!localSubscriptions[reservation.topic] &&
|
||||||
<Chip icon={<Info/>} label="Not subscribed" color="primary" variant="outlined"/>
|
<Chip icon={<Info/>} label="Not subscribed" color="primary" variant="outlined"/>
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue