A few manager tests
This commit is contained in:
parent
73b0161ff7
commit
8bf64d8723
5 changed files with 169 additions and 6 deletions
|
@ -447,7 +447,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
||||||
// Check if we are allowed to reserve this topic
|
// Check if we are allowed to reserve this topic
|
||||||
if u.IsUser() && u.Tier == nil {
|
if u.IsUser() && u.Tier == nil {
|
||||||
return errHTTPUnauthorized
|
return errHTTPUnauthorized
|
||||||
} else if err := s.userManager.CheckAllowAccess(u.Name, req.Topic); err != nil {
|
} else if err := s.userManager.AllowReservation(u.Name, req.Topic); err != nil {
|
||||||
return errHTTPConflictTopicReserved
|
return errHTTPConflictTopicReserved
|
||||||
} else if u.IsUser() {
|
} else if u.IsUser() {
|
||||||
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
hasReservation, err := s.userManager.HasReservation(u.Name, req.Topic)
|
||||||
|
|
|
@ -1017,9 +1017,9 @@ func (a *Manager) checkReservationsLimit(username string, reservationsLimit int6
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckAllowAccess tests if a user may create an access control entry for the given topic.
|
// AllowReservation tests if a user may create an access control entry for the given topic.
|
||||||
// If there are any ACL entries that are not owned by the user, an error is returned.
|
// If there are any ACL entries that are not owned by the user, an error is returned.
|
||||||
func (a *Manager) CheckAllowAccess(username string, topic string) error {
|
func (a *Manager) AllowReservation(username string, topic string) error {
|
||||||
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
|
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
|
||||||
return ErrInvalidArgument
|
return ErrInvalidArgument
|
||||||
}
|
}
|
||||||
|
|
|
@ -106,6 +106,30 @@ func TestManager_AddUser_Timing(t *testing.T) {
|
||||||
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
require.GreaterOrEqual(t, time.Now().UnixMilli()-start, minBcryptTimingMillis)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_AddUser_And_Query(t *testing.T) {
|
||||||
|
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
||||||
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||||
|
require.Nil(t, a.ChangeBilling("user", &Billing{
|
||||||
|
StripeCustomerID: "acct_123",
|
||||||
|
StripeSubscriptionID: "sub_123",
|
||||||
|
StripeSubscriptionStatus: "active",
|
||||||
|
StripeSubscriptionPaidUntil: time.Now().Add(time.Hour),
|
||||||
|
StripeSubscriptionCancelAt: time.Unix(0, 0),
|
||||||
|
}))
|
||||||
|
|
||||||
|
u, err := a.User("user")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, "user", u.Name)
|
||||||
|
|
||||||
|
u2, err := a.UserByID(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, u.Name, u2.Name)
|
||||||
|
|
||||||
|
u3, err := a.UserByStripeCustomer("acct_123")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, u.ID, u3.ID)
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_Authenticate_Timing(t *testing.T) {
|
func TestManager_Authenticate_Timing(t *testing.T) {
|
||||||
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
a := newTestManagerFromFile(t, filepath.Join(t.TempDir(), "user.db"), "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
|
||||||
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
require.Nil(t, a.AddUser("user", "pass", RoleAdmin))
|
||||||
|
@ -311,6 +335,7 @@ func TestManager_ChangeRole(t *testing.T) {
|
||||||
|
|
||||||
func TestManager_Reservations(t *testing.T) {
|
func TestManager_Reservations(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
|
require.Nil(t, a.AddUser("phil", "phil", RoleUser))
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
|
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll))
|
||||||
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
|
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
|
||||||
|
@ -329,6 +354,32 @@ func TestManager_Reservations(t *testing.T) {
|
||||||
Owner: PermissionReadWrite,
|
Owner: PermissionReadWrite,
|
||||||
Everyone: PermissionDenyAll,
|
Everyone: PermissionDenyAll,
|
||||||
}, reservations[1])
|
}, reservations[1])
|
||||||
|
|
||||||
|
b, err := a.HasReservation("ben", "readme")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, b)
|
||||||
|
|
||||||
|
b, err = a.HasReservation("notben", "readme")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.False(t, b)
|
||||||
|
|
||||||
|
b, err = a.HasReservation("ben", "something-else")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.False(t, b)
|
||||||
|
|
||||||
|
count, err := a.ReservationsCount("ben")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(2), count)
|
||||||
|
|
||||||
|
count, err = a.ReservationsCount("phil")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), count)
|
||||||
|
|
||||||
|
err = a.AllowReservation("phil", "readme")
|
||||||
|
require.Equal(t, errTopicOwnedByOthers, err)
|
||||||
|
|
||||||
|
err = a.AllowReservation("phil", "not-reserved")
|
||||||
|
require.Nil(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
func TestManager_ChangeRoleFromTierUserToAdmin(t *testing.T) {
|
||||||
|
@ -414,11 +465,24 @@ func TestManager_Token_Valid(t *testing.T) {
|
||||||
require.Equal(t, token.Value, token2.Value)
|
require.Equal(t, token.Value, token2.Value)
|
||||||
require.Equal(t, "some label", token2.Label)
|
require.Equal(t, "some label", token2.Label)
|
||||||
|
|
||||||
|
tokens, err := a.Tokens(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 1, len(tokens))
|
||||||
|
require.Equal(t, "some label", tokens[0].Label)
|
||||||
|
|
||||||
|
tokens, err = a.Tokens("u_notauser")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(tokens))
|
||||||
|
|
||||||
// Remove token and auth again
|
// Remove token and auth again
|
||||||
require.Nil(t, a.RemoveToken(u2.ID, u2.Token))
|
require.Nil(t, a.RemoveToken(u2.ID, u2.Token))
|
||||||
u3, err := a.AuthenticateToken(token.Value)
|
u3, err := a.AuthenticateToken(token.Value)
|
||||||
require.Equal(t, ErrUnauthenticated, err)
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
require.Nil(t, u3)
|
require.Nil(t, u3)
|
||||||
|
|
||||||
|
tokens, err = a.Tokens(u.ID)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, 0, len(tokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_Token_Invalid(t *testing.T) {
|
func TestManager_Token_Invalid(t *testing.T) {
|
||||||
|
@ -434,6 +498,12 @@ func TestManager_Token_Invalid(t *testing.T) {
|
||||||
require.Equal(t, ErrUnauthenticated, err)
|
require.Equal(t, ErrUnauthenticated, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestManager_Token_NotFound(t *testing.T) {
|
||||||
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
|
_, err := a.Token("u_bla", "notfound")
|
||||||
|
require.Equal(t, ErrTokenNotFound, err)
|
||||||
|
}
|
||||||
|
|
||||||
func TestManager_Token_Expire(t *testing.T) {
|
func TestManager_Token_Expire(t *testing.T) {
|
||||||
a := newTestManager(t, PermissionDenyAll)
|
a := newTestManager(t, PermissionDenyAll)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
@ -552,7 +622,7 @@ func TestManager_Token_MaxCount_AutoDelete(t *testing.T) {
|
||||||
require.Equal(t, 20, count)
|
require.Equal(t, 20, count)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_EnqueueStats(t *testing.T) {
|
func TestManager_EnqueueStats_ResetStats(t *testing.T) {
|
||||||
a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 1500*time.Millisecond)
|
a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 1500*time.Millisecond)
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
@ -580,6 +650,51 @@ func TestManager_EnqueueStats(t *testing.T) {
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
require.Equal(t, int64(11), u.Stats.Messages)
|
require.Equal(t, int64(11), u.Stats.Messages)
|
||||||
require.Equal(t, int64(2), u.Stats.Emails)
|
require.Equal(t, int64(2), u.Stats.Emails)
|
||||||
|
|
||||||
|
// Now reset stats (enqueued stats will be thrown out)
|
||||||
|
a.EnqueueUserStats(u.ID, &Stats{
|
||||||
|
Messages: 99,
|
||||||
|
Emails: 23,
|
||||||
|
})
|
||||||
|
require.Nil(t, a.ResetStats())
|
||||||
|
|
||||||
|
u, err = a.User("ben")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, int64(0), u.Stats.Messages)
|
||||||
|
require.Equal(t, int64(0), u.Stats.Emails)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_EnqueueTokenUpdate(t *testing.T) {
|
||||||
|
a, err := NewManager(filepath.Join(t.TempDir(), "db"), "", PermissionReadWrite, bcrypt.MinCost, 500*time.Millisecond)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Nil(t, a.AddUser("ben", "ben", RoleUser))
|
||||||
|
|
||||||
|
// Create user and token
|
||||||
|
u, err := a.User("ben")
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
token, err := a.CreateToken(u.ID, "", time.Now().Add(time.Hour), netip.IPv4Unspecified())
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
// Queue token update
|
||||||
|
a.EnqueueTokenUpdate(token.Value, &TokenUpdate{
|
||||||
|
LastAccess: time.Unix(111, 0).UTC(),
|
||||||
|
LastOrigin: netip.MustParseAddr("1.2.3.3"),
|
||||||
|
})
|
||||||
|
|
||||||
|
// Token has not changed yet.
|
||||||
|
token2, err := a.Token(u.ID, token.Value)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, token.LastAccess.Unix(), token2.LastAccess.Unix())
|
||||||
|
require.Equal(t, token.LastOrigin, token2.LastOrigin)
|
||||||
|
|
||||||
|
// After a second or so they should be persisted
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
token3, err := a.Token(u.ID, token.Value)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, time.Unix(111, 0).UTC().Unix(), token3.LastAccess.Unix())
|
||||||
|
require.Equal(t, netip.MustParseAddr("1.2.3.3"), token3.LastOrigin)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ChangeSettings(t *testing.T) {
|
func TestManager_ChangeSettings(t *testing.T) {
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -97,7 +98,7 @@ type Tier struct {
|
||||||
func (t *Tier) Context() log.Context {
|
func (t *Tier) Context() log.Context {
|
||||||
return log.Context{
|
return log.Context{
|
||||||
"tier_id": t.ID,
|
"tier_id": t.ID,
|
||||||
"tier_name": t.Name,
|
"tier_code": t.Code,
|
||||||
"stripe_price_id": t.StripePriceID,
|
"stripe_price_id": t.StripePriceID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -170,7 +171,7 @@ func NewPermission(read, write bool) Permission {
|
||||||
|
|
||||||
// ParsePermission parses the string representation and returns a Permission
|
// ParsePermission parses the string representation and returns a Permission
|
||||||
func ParsePermission(s string) (Permission, error) {
|
func ParsePermission(s string) (Permission, error) {
|
||||||
switch s {
|
switch strings.ToLower(s) {
|
||||||
case "read-write", "rw":
|
case "read-write", "rw":
|
||||||
return NewPermission(true, true), nil
|
return NewPermission(true, true), nil
|
||||||
case "read-only", "read", "ro":
|
case "read-only", "read", "ro":
|
||||||
|
|
|
@ -10,4 +10,51 @@ func TestPermission(t *testing.T) {
|
||||||
require.Equal(t, PermissionRead, NewPermission(true, false))
|
require.Equal(t, PermissionRead, NewPermission(true, false))
|
||||||
require.Equal(t, PermissionWrite, NewPermission(false, true))
|
require.Equal(t, PermissionWrite, NewPermission(false, true))
|
||||||
require.Equal(t, PermissionDenyAll, NewPermission(false, false))
|
require.Equal(t, PermissionDenyAll, NewPermission(false, false))
|
||||||
|
require.True(t, PermissionReadWrite.IsReadWrite())
|
||||||
|
require.True(t, PermissionReadWrite.IsRead())
|
||||||
|
require.True(t, PermissionReadWrite.IsWrite())
|
||||||
|
require.True(t, PermissionRead.IsRead())
|
||||||
|
require.True(t, PermissionWrite.IsWrite())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePermission(t *testing.T) {
|
||||||
|
_, err := ParsePermission("no")
|
||||||
|
require.NotNil(t, err)
|
||||||
|
|
||||||
|
p, err := ParsePermission("read-write")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, PermissionReadWrite, p)
|
||||||
|
|
||||||
|
p, err = ParsePermission("rw")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, PermissionReadWrite, p)
|
||||||
|
|
||||||
|
p, err = ParsePermission("read-only")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, PermissionRead, p)
|
||||||
|
|
||||||
|
p, err = ParsePermission("WRITE")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, PermissionWrite, p)
|
||||||
|
|
||||||
|
p, err = ParsePermission("deny-all")
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.Equal(t, PermissionDenyAll, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllowedTier(t *testing.T) {
|
||||||
|
require.False(t, AllowedTier(" no"))
|
||||||
|
require.True(t, AllowedTier("yes"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTierContext(t *testing.T) {
|
||||||
|
tier := &Tier{
|
||||||
|
ID: "ti_abc",
|
||||||
|
Code: "pro",
|
||||||
|
StripePriceID: "price_123",
|
||||||
|
}
|
||||||
|
context := tier.Context()
|
||||||
|
require.Equal(t, "ti_abc", context["tier_id"])
|
||||||
|
require.Equal(t, "pro", context["tier_code"])
|
||||||
|
require.Equal(t, "price_123", context["stripe_price_id"])
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue