diff --git a/backend/app/api/routes.go b/backend/app/api/routes.go index cab1a14..9049bb5 100644 --- a/backend/app/api/routes.go +++ b/backend/app/api/routes.go @@ -13,6 +13,7 @@ import ( "github.com/hay-kot/homebox/backend/app/api/handlers/debughandlers" v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1" _ "github.com/hay-kot/homebox/backend/app/api/static/docs" + "github.com/hay-kot/homebox/backend/internal/data/ent/authroles" "github.com/hay-kot/homebox/backend/internal/data/repo" "github.com/hay-kot/homebox/backend/pkgs/server" httpSwagger "github.com/swaggo/http-swagger" // http-swagger middleware @@ -68,45 +69,55 @@ func (a *app) mountRoutes(repos *repo.AllRepos) { // and also needs to be outside of the `auth` middleware. a.server.Get(v1Base("/items/{id}/attachments/download"), v1Ctrl.HandleItemAttachmentDownload()) - a.server.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf(), a.mwAuthToken) - a.server.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate(), a.mwAuthToken) - a.server.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete(), a.mwAuthToken) - a.server.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout(), a.mwAuthToken) - a.server.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh(), a.mwAuthToken) - a.server.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword(), a.mwAuthToken) + userMW := []server.Middleware{ + a.mwAuthToken, + a.mwRoles(RoleModeOr, authroles.RoleUser.String()), + } - a.server.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate(), a.mwAuthToken) - a.server.Get(v1Base("/groups/statistics"), v1Ctrl.HandleGroupStatistics(), a.mwAuthToken) + a.server.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf(), userMW...) + a.server.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate(), userMW...) + a.server.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete(), userMW...) + a.server.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout(), userMW...) + a.server.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh(), userMW...) + a.server.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword(), userMW...) + + a.server.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate(), userMW...) + a.server.Get(v1Base("/groups/statistics"), v1Ctrl.HandleGroupStatistics(), userMW...) // TODO: I don't like /groups being the URL for users - a.server.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet(), a.mwAuthToken) - a.server.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate(), a.mwAuthToken) + a.server.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet(), userMW...) + a.server.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate(), userMW...) - a.server.Post(v1Base("/actions/ensure-asset-ids"), v1Ctrl.HandleEnsureAssetID(), a.mwAuthToken) + a.server.Post(v1Base("/actions/ensure-asset-ids"), v1Ctrl.HandleEnsureAssetID(), userMW...) - a.server.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll(), a.mwAuthToken) - a.server.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate(), a.mwAuthToken) - a.server.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet(), a.mwAuthToken) - a.server.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate(), a.mwAuthToken) - a.server.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete(), a.mwAuthToken) + a.server.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll(), userMW...) + a.server.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate(), userMW...) + a.server.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet(), userMW...) + a.server.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate(), userMW...) + a.server.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete(), userMW...) - a.server.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll(), a.mwAuthToken) - a.server.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate(), a.mwAuthToken) - a.server.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet(), a.mwAuthToken) - a.server.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate(), a.mwAuthToken) - a.server.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete(), a.mwAuthToken) + a.server.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll(), userMW...) + a.server.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate(), userMW...) + a.server.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet(), userMW...) + a.server.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate(), userMW...) + a.server.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete(), userMW...) - a.server.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll(), a.mwAuthToken) - a.server.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport(), a.mwAuthToken) - a.server.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate(), a.mwAuthToken) - a.server.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet(), a.mwAuthToken) - a.server.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate(), a.mwAuthToken) - a.server.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete(), a.mwAuthToken) + a.server.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll(), userMW...) + a.server.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport(), userMW...) + a.server.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate(), userMW...) + a.server.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet(), userMW...) + a.server.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate(), userMW...) + a.server.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete(), userMW...) - a.server.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate(), a.mwAuthToken) - a.server.Get(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentToken(), a.mwAuthToken) - a.server.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate(), a.mwAuthToken) - a.server.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete(), a.mwAuthToken) + a.server.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate(), userMW...) + a.server.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate(), userMW...) + a.server.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete(), userMW...) + + a.server.Get( + v1Base("/items/{id}/attachments/{attachment_id}"), + v1Ctrl.HandleItemAttachmentGet(), + a.mwAuthToken, a.mwRoles(RoleModeOr, authroles.RoleUser.String(), authroles.RoleAttachments.String()), + ) a.server.NotFound(notFoundHandler()) } diff --git a/backend/internal/data/repo/repo_tokens.go b/backend/internal/data/repo/repo_tokens.go index 7d9115b..5820227 100644 --- a/backend/internal/data/repo/repo_tokens.go +++ b/backend/internal/data/repo/repo_tokens.go @@ -6,7 +6,10 @@ import ( "github.com/google/uuid" "github.com/hay-kot/homebox/backend/internal/data/ent" + "github.com/hay-kot/homebox/backend/internal/data/ent/authroles" "github.com/hay-kot/homebox/backend/internal/data/ent/authtokens" + "github.com/hay-kot/homebox/backend/pkgs/hasher" + "github.com/hay-kot/homebox/backend/pkgs/set" ) type TokenRepository struct { @@ -47,9 +50,31 @@ func (r *TokenRepository) GetUserFromToken(ctx context.Context, token []byte) (U return mapUserOut(user), nil } -// Creates a token for a user -func (r *TokenRepository) CreateToken(ctx context.Context, createToken UserAuthTokenCreate) (UserAuthToken, error) { +func (r *TokenRepository) GetRoles(ctx context.Context, token string) (*set.Set[string], error) { + tokenHash := hasher.HashToken(token) + roles, err := r.db.AuthRoles. + Query(). + Where(authroles.HasTokenWith( + authtokens.Token(tokenHash), + )). + All(ctx) + + if err != nil { + return nil, err + } + + roleSet := set.Make[string](len(roles)) + + for _, role := range roles { + roleSet.Insert(role.Role.String()) + } + + return &roleSet, nil +} + +// Creates a token for a user +func (r *TokenRepository) CreateToken(ctx context.Context, createToken UserAuthTokenCreate, roles ...authroles.Role) (UserAuthToken, error) { dbToken, err := r.db.AuthTokens.Create(). SetToken(createToken.TokenHash). SetUserID(createToken.UserID). @@ -60,6 +85,17 @@ func (r *TokenRepository) CreateToken(ctx context.Context, createToken UserAuthT return UserAuthToken{}, err } + for _, role := range roles { + _, err := r.db.AuthRoles.Create(). + SetRole(role). + SetToken(dbToken). + Save(ctx) + + if err != nil { + return UserAuthToken{}, err + } + } + return UserAuthToken{ UserAuthTokenCreate: UserAuthTokenCreate{ TokenHash: dbToken.Token, diff --git a/backend/pkgs/set/set.go b/backend/pkgs/set/set.go index f2ffecc..b0918bb 100644 --- a/backend/pkgs/set/set.go +++ b/backend/pkgs/set/set.go @@ -8,6 +8,12 @@ type Set[T key] struct { mp map[T]struct{} } +func Make[T key](size int) Set[T] { + return Set[T]{ + mp: make(map[T]struct{}, size), + } +} + func New[T key](v ...T) Set[T] { mp := make(map[T]struct{}, len(v))