fix: inaccruate 401 & sql busy error (#679)

* fix inaccruate 401 error on SQL db error

* init golangci-lint config

* linter autofix

* testify auto fixes

* fix sqlite busy errors

* fix naming

* more linter errors

* fix rest of linter issues

Former-commit-id: e8449b3a73
This commit is contained in:
Hayden 2024-01-04 11:55:26 -06:00 committed by GitHub
parent 5e83b28ff5
commit 03df23d97c
62 changed files with 389 additions and 292 deletions

View file

@ -16,7 +16,7 @@
"editor.formatOnSave": false, "editor.formatOnSave": false,
"editor.defaultFormatter": "dbaeumer.vscode-eslint", "editor.defaultFormatter": "dbaeumer.vscode-eslint",
"editor.codeActionsOnSave": { "editor.codeActionsOnSave": {
"source.fixAll.eslint": true "source.fixAll.eslint": "explicit"
}, },
"[typescript]": { "[typescript]": {
"editor.defaultFormatter": "dbaeumer.vscode-eslint" "editor.defaultFormatter": "dbaeumer.vscode-eslint"

View file

@ -32,7 +32,7 @@ FROM alpine:latest
ENV HBOX_MODE=production ENV HBOX_MODE=production
ENV HBOX_STORAGE_DATA=/data/ ENV HBOX_STORAGE_DATA=/data/
ENV HBOX_STORAGE_SQLITE_URL=/data/homebox.db?_fk=1 ENV HBOX_STORAGE_SQLITE_URL=/data/homebox.db?_pragma=busy_timeout=2000&_pragma=journal_mode=WAL&_fk=1
RUN apk --no-cache add ca-certificates RUN apk --no-cache add ca-certificates
RUN mkdir /app RUN mkdir /app

73
backend/.golangci.yml Normal file
View file

@ -0,0 +1,73 @@
run:
timeout: 10m
skip-dirs:
- internal/data/ent.*
linters-settings:
goconst:
min-len: 5
min-occurrences: 5
exhaustive:
default-signifies-exhaustive: true
revive:
ignore-generated-header: false
severity: warning
confidence: 3
depguard:
rules:
main:
deny:
- pkg: io/util
desc: |
Deprecated: As of Go 1.16, the same functionality is now provided by
package io or package os, and those implementations should be
preferred in new code. See the specific function documentation for
details.
gocritic:
enabled-checks:
- ruleguard
testifylint:
enable-all: true
tagalign:
order:
- json
- schema
- yaml
- yml
- toml
- validate
linters:
disable-all: true
enable:
- asciicheck
- bodyclose
- depguard
- dogsled
- errcheck
- errorlint
- exhaustive
- exportloopref
- gochecknoinits
- goconst
- gocritic
- gocyclo
- goprintffuncname
- gosimple
- govet
- ineffassign
- misspell
- nakedret
- revive
- staticcheck
- stylecheck
- tagalign
- testifylint
- typecheck
- typecheck
- unconvert
- unused
- whitespace
- zerologlint
- sqlclosecheck
issues:
exclude-use-default: false
fix: true

View file

@ -15,7 +15,7 @@ func (a *app) SetupDemo() {
,Office,IOT;Home Assistant; Z-Wave,1,Zooz 110v Power Switch,"Zooz Z-Wave Plus Power Switch ZEN15 for 110V AC Units, Sump Pumps, Humidifiers, and More",,,ZEN15,Zooz,,Amazon,39.95,10/13/2021,,,,,,, ,Office,IOT;Home Assistant; Z-Wave,1,Zooz 110v Power Switch,"Zooz Z-Wave Plus Power Switch ZEN15 for 110V AC Units, Sump Pumps, Humidifiers, and More",,,ZEN15,Zooz,,Amazon,39.95,10/13/2021,,,,,,,
,Downstairs,IOT;Home Assistant; Z-Wave,1,Ecolink Z-Wave PIR Motion Sensor,"Ecolink Z-Wave PIR Motion Detector Pet Immune, White (PIRZWAVE2.5-ECO)",,,PIRZWAVE2.5-ECO,Ecolink,,Amazon,35.58,10/21/2020,,,,,,, ,Downstairs,IOT;Home Assistant; Z-Wave,1,Ecolink Z-Wave PIR Motion Sensor,"Ecolink Z-Wave PIR Motion Detector Pet Immune, White (PIRZWAVE2.5-ECO)",,,PIRZWAVE2.5-ECO,Ecolink,,Amazon,35.58,10/21/2020,,,,,,,
,Entry,IOT;Home Assistant; Z-Wave,1,Yale Security Touchscreen Deadbolt,"Yale Security YRD226-ZW2-619 YRD226ZW2619 Touchscreen Deadbolt, Satin Nickel",,,YRD226ZW2619,Yale,,Amazon,120.39,10/14/2020,,,,,,, ,Entry,IOT;Home Assistant; Z-Wave,1,Yale Security Touchscreen Deadbolt,"Yale Security YRD226-ZW2-619 YRD226ZW2619 Touchscreen Deadbolt, Satin Nickel",,,YRD226ZW2619,Yale,,Amazon,120.39,10/14/2020,,,,,,,
,Kitchen,IOT;Home Assistant; Z-Wave,1,Smart Rocker Light Dimmer,"UltraPro Z-Wave Smart Rocker Light Dimmer with QuickFit and SimpleWire, 3-Way Ready, Compatible with Alexa, Google Assistant, ZWave Hub Required, Repeater/Range Extender, White Paddle Only, 39351",,,39351,Honeywell,,Amazon,65.98,09/30/0202,,,,,,, ,Kitchen,IOT;Home Assistant; Z-Wave,1,Smart Rocker Light Dimmer,"UltraPro Z-Wave Smart Rocker Light Dimmer with QuickFit and SimpleWire, 3-Way Ready, Compatible with Alexa, Google Assistant, ZWave Hub Required, Repeater/Range Extender, White Paddle Only, 39351",,,39351,Honeywell,,Amazon,65.98,09/30/0202,,,,,,,
` `
registration := services.UserRegistration{ registration := services.UserRegistration{

View file

@ -1,3 +1,4 @@
// Package debughandlers provides handlers for debugging.
package debughandlers package debughandlers
import ( import (

View file

@ -1,3 +1,4 @@
// Package v1 provides the API handlers for version 1 of the API.
package v1 package v1
import ( import (
@ -74,7 +75,7 @@ type (
BuildTime string `json:"buildTime"` BuildTime string `json:"buildTime"`
} }
ApiSummary struct { APISummary struct {
Healthy bool `json:"health"` Healthy bool `json:"health"`
Versions []string `json:"versions"` Versions []string `json:"versions"`
Title string `json:"title"` Title string `json:"title"`
@ -85,7 +86,7 @@ type (
} }
) )
func BaseUrlFunc(prefix string) func(s string) string { func BaseURLFunc(prefix string) func(s string) string {
return func(s string) string { return func(s string) string {
return prefix + "/v1" + s return prefix + "/v1" + s
} }
@ -115,7 +116,7 @@ func NewControllerV1(svc *services.AllServices, repos *repo.AllRepos, bus *event
// @Router /v1/status [GET] // @Router /v1/status [GET]
func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) errchain.HandlerFunc { func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) errchain.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error {
return server.JSON(w, http.StatusOK, ApiSummary{ return server.JSON(w, http.StatusOK, APISummary{
Healthy: ready(), Healthy: ready(),
Title: "Homebox", Title: "Homebox",
Message: "Track, Manage, and Organize your Things", Message: "Track, Manage, and Organize your Things",

View file

@ -27,10 +27,10 @@ import (
func (ctrl *V1Controller) HandleAssetGet() errchain.HandlerFunc { func (ctrl *V1Controller) HandleAssetGet() errchain.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) error { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
assetIdParam := chi.URLParam(r, "id") assetIDParam := chi.URLParam(r, "id")
assetIdParam = strings.ReplaceAll(assetIdParam, "-", "") // Remove dashes assetIDParam = strings.ReplaceAll(assetIDParam, "-", "") // Remove dashes
// Convert the asset ID to an int64 // Convert the asset ID to an int64
assetId, err := strconv.ParseInt(assetIdParam, 10, 64) assetID, err := strconv.ParseInt(assetIDParam, 10, 64)
if err != nil { if err != nil {
return err return err
} }
@ -52,7 +52,7 @@ func (ctrl *V1Controller) HandleAssetGet() errchain.HandlerFunc {
} }
} }
items, err := ctrl.repo.Items.QueryByAssetID(r.Context(), ctx.GID, repo.AssetID(assetId), int(page), int(pageSize)) items, err := ctrl.repo.Items.QueryByAssetID(r.Context(), ctx.GID, repo.AssetID(assetID), int(page), int(pageSize))
if err != nil { if err != nil {
log.Err(err).Msg("failed to get item") log.Err(err).Msg("failed to get item")
return validate.NewRequestError(err, http.StatusInternalServerError) return validate.NewRequestError(err, http.StatusInternalServerError)

View file

@ -153,7 +153,7 @@ func (ctrl *V1Controller) HandleAuthLogout() errchain.HandlerFunc {
} }
} }
// HandleAuthLogout godoc // HandleAuthRefresh godoc
// //
// @Summary User Token Refresh // @Summary User Token Refresh
// @Description handleAuthRefresh returns a handler that will issue a new token from an existing token. // @Description handleAuthRefresh returns a handler that will issue a new token from an existing token.

View file

@ -233,7 +233,6 @@ func (ctrl *V1Controller) HandleGetAllCustomFieldValues() errchain.HandlerFunc {
} }
return adapters.Query(fn, http.StatusOK) return adapters.Query(fn, http.StatusOK)
} }
// HandleItemsImport godocs // HandleItemsImport godocs

View file

@ -39,7 +39,6 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() errchain.HandlerFunc {
if err != nil { if err != nil {
log.Err(err).Msg("failed to parse multipart form") log.Err(err).Msg("failed to parse multipart form")
return validate.NewRequestError(errors.New("failed to parse multipart form"), http.StatusBadRequest) return validate.NewRequestError(errors.New("failed to parse multipart form"), http.StatusBadRequest)
} }
errs := validate.NewFieldErrors() errs := validate.NewFieldErrors()

View file

@ -10,7 +10,7 @@ import (
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
) )
// HandleLocationTreeQuery // HandleLocationTreeQuery godoc
// //
// @Summary Get Locations Tree // @Summary Get Locations Tree
// @Tags Locations // @Tags Locations
@ -28,7 +28,7 @@ func (ctrl *V1Controller) HandleLocationTreeQuery() errchain.HandlerFunc {
return adapters.Query(fn, http.StatusOK) return adapters.Query(fn, http.StatusOK)
} }
// HandleLocationGetAll // HandleLocationGetAll godoc
// //
// @Summary Get All Locations // @Summary Get All Locations
// @Tags Locations // @Tags Locations
@ -46,7 +46,7 @@ func (ctrl *V1Controller) HandleLocationGetAll() errchain.HandlerFunc {
return adapters.Query(fn, http.StatusOK) return adapters.Query(fn, http.StatusOK)
} }
// HandleLocationCreate // HandleLocationCreate godoc
// //
// @Summary Create Location // @Summary Create Location
// @Tags Locations // @Tags Locations
@ -64,7 +64,7 @@ func (ctrl *V1Controller) HandleLocationCreate() errchain.HandlerFunc {
return adapters.Action(fn, http.StatusCreated) return adapters.Action(fn, http.StatusCreated)
} }
// HandleLocationDelete // HandleLocationDelete godoc
// //
// @Summary Delete Location // @Summary Delete Location
// @Tags Locations // @Tags Locations
@ -83,7 +83,7 @@ func (ctrl *V1Controller) HandleLocationDelete() errchain.HandlerFunc {
return adapters.CommandID("id", fn, http.StatusNoContent) return adapters.CommandID("id", fn, http.StatusNoContent)
} }
// HandleLocationGet // HandleLocationGet godoc
// //
// @Summary Get Location // @Summary Get Location
// @Tags Locations // @Tags Locations
@ -101,7 +101,7 @@ func (ctrl *V1Controller) HandleLocationGet() errchain.HandlerFunc {
return adapters.CommandID("id", fn, http.StatusOK) return adapters.CommandID("id", fn, http.StatusOK)
} }
// HandleLocationUpdate // HandleLocationUpdate godoc
// //
// @Summary Update Location // @Summary Update Location
// @Tags Locations // @Tags Locations

View file

@ -10,7 +10,7 @@ import (
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
) )
// HandleMaintenanceGetLog godoc // HandleMaintenanceLogGet godoc
// //
// @Summary Get Maintenance Log // @Summary Get Maintenance Log
// @Tags Maintenance // @Tags Maintenance

View file

@ -12,7 +12,7 @@ import (
"github.com/hay-kot/httpkit/server" "github.com/hay-kot/httpkit/server"
) )
// HandleGroupGet godoc // HandleGroupStatisticsLocations godoc
// //
// @Summary Get Location Statistics // @Summary Get Location Statistics
// @Tags Statistics // @Tags Statistics

View file

@ -78,12 +78,12 @@ func run(cfg *config.Config) error {
log.Fatal().Err(err).Msg("failed to create data directory") log.Fatal().Err(err).Msg("failed to create data directory")
} }
c, err := ent.Open("sqlite3", cfg.Storage.SqliteUrl) c, err := ent.Open("sqlite3", cfg.Storage.SqliteURL)
if err != nil { if err != nil {
log.Fatal(). log.Fatal().
Err(err). Err(err).
Str("driver", "sqlite"). Str("driver", "sqlite").
Str("url", cfg.Storage.SqliteUrl). Str("url", cfg.Storage.SqliteURL).
Msg("failed opening connection to sqlite") Msg("failed opening connection to sqlite")
} }
defer func(c *ent.Client) { defer func(c *ent.Client) {
@ -116,7 +116,7 @@ func run(cfg *config.Config) error {
log.Fatal(). log.Fatal().
Err(err). Err(err).
Str("driver", "sqlite"). Str("driver", "sqlite").
Str("url", cfg.Storage.SqliteUrl). Str("url", cfg.Storage.SqliteURL).
Msg("failed creating schema resources") Msg("failed creating schema resources")
} }

View file

@ -9,6 +9,7 @@ import (
v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1" v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1"
"github.com/hay-kot/homebox/backend/internal/core/services" "github.com/hay-kot/homebox/backend/internal/core/services"
"github.com/hay-kot/homebox/backend/internal/data/ent"
"github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
) )
@ -130,7 +131,7 @@ func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler {
} }
if requestToken == "" { if requestToken == "" {
return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized) return validate.NewRequestError(errors.New("authorization header or query is required"), http.StatusUnauthorized)
} }
requestToken = strings.TrimPrefix(requestToken, "Bearer ") requestToken = strings.TrimPrefix(requestToken, "Bearer ")
@ -140,7 +141,11 @@ func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler {
usr, err := a.services.User.GetSelf(r.Context(), requestToken) usr, err := a.services.User.GetSelf(r.Context(), requestToken)
// Check the database for the token // Check the database for the token
if err != nil { if err != nil {
return validate.NewRequestError(errors.New("valid authorization header is required"), http.StatusUnauthorized) if ent.IsNotFound(err) {
return validate.NewRequestError(errors.New("valid authorization token is required"), http.StatusUnauthorized)
}
return err
} }
r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken)) r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken))

View file

@ -0,0 +1,2 @@
// Package providers provides a authentication abstraction for the backend.
package providers

View file

@ -47,7 +47,7 @@ func (a *app) mountRoutes(r *chi.Mux, chain *errchain.ErrChain, repos *repo.AllR
// ========================================================================= // =========================================================================
// API Version 1 // API Version 1
v1Base := v1.BaseUrlFunc(prefix) v1Base := v1.BaseURLFunc(prefix)
v1Ctrl := v1.NewControllerV1( v1Ctrl := v1.NewControllerV1(
a.services, a.services,
@ -183,7 +183,7 @@ func notFoundHandler() errchain.HandlerFunc {
if err != nil { if err != nil {
return err return err
} }
defer f.Close() defer func() { _ = f.Close() }()
stat, _ := f.Stat() stat, _ := f.Stat()
if stat.IsDir() { if stat.IsDir() {

View file

@ -1,3 +1,4 @@
// Package services provides the core business logic for the application.
package services package services
import ( import (

View file

@ -62,7 +62,7 @@ func TestMain(m *testing.M) {
tClient = client tClient = client
tRepos = repo.New(tClient, tbus, os.TempDir()+"/homebox") tRepos = repo.New(tClient, tbus, os.TempDir()+"/homebox")
tSvc = New(tRepos) tSvc = New(tRepos)
defer client.Close() defer func() { _ = client.Close() }()
bootstrap() bootstrap()
tCtx = Context{ tCtx = Context{

View file

@ -1,4 +1,4 @@
// / Package eventbus provides an interface for event bus. // Package eventbus provides an interface for event bus.
package eventbus package eventbus
import ( import (

View file

@ -1,3 +1,4 @@
// Package reporting provides a way to import CSV files into the database.
package reporting package reporting
import ( import (

View file

@ -152,7 +152,7 @@ func (s *IOSheet) Read(data io.Reader) error {
return nil return nil
} }
// Write writes the sheet to a writer. // ReadItems writes the sheet to a writer.
func (s *IOSheet) ReadItems(ctx context.Context, items []repo.ItemOut, GID uuid.UUID, repos *repo.AllRepos) error { func (s *IOSheet) ReadItems(ctx context.Context, items []repo.ItemOut, GID uuid.UUID, repos *repo.AllRepos) error {
s.Rows = make([]ExportTSVRow, len(items)) s.Rows = make([]ExportTSVRow, len(items))
@ -162,9 +162,9 @@ func (s *IOSheet) ReadItems(ctx context.Context, items []repo.ItemOut, GID uuid.
item := items[i] item := items[i]
// TODO: Support fetching nested locations // TODO: Support fetching nested locations
locId := item.Location.ID locID := item.Location.ID
locPaths, err := repos.Locations.PathForLoc(context.Background(), GID, locId) locPaths, err := repos.Locations.PathForLoc(context.Background(), GID, locID)
if err != nil { if err != nil {
log.Error().Err(err).Msg("could not get location path") log.Error().Err(err).Msg("could not get location path")
return err return err
@ -252,7 +252,7 @@ func (s *IOSheet) ReadItems(ctx context.Context, items []repo.ItemOut, GID uuid.
return nil return nil
} }
// Writes the current sheet to a writer in TSV format. // TSV writes the current sheet to a writer in TSV format.
func (s *IOSheet) TSV() ([][]string, error) { func (s *IOSheet) TSV() ([][]string, error) {
memcsv := make([][]string, len(s.Rows)+1) memcsv := make([][]string, len(s.Rows)+1)

View file

@ -9,6 +9,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/data/repo" "github.com/hay-kot/homebox/backend/internal/data/repo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var ( var (
@ -103,9 +104,9 @@ func TestSheet_Read(t *testing.T) {
switch { switch {
case tt.wantErr: case tt.wantErr:
assert.Error(t, err) require.Error(t, err)
default: default:
assert.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, tt.want, sheet.Rows) assert.ElementsMatch(t, tt.want, sheet.Rows)
} }
}) })

View file

@ -32,7 +32,7 @@ func (svc *ItemService) Create(ctx Context, item repo.ItemCreate) (repo.ItemOut,
return repo.ItemOut{}, err return repo.ItemOut{}, err
} }
item.AssetID = repo.AssetID(highest + 1) item.AssetID = highest + 1
} }
return svc.repo.Items.Create(ctx, ctx.GID, item) return svc.repo.Items.Create(ctx, ctx.GID, item)
@ -53,7 +53,7 @@ func (svc *ItemService) EnsureAssetID(ctx context.Context, GID uuid.UUID) (int,
for _, item := range items { for _, item := range items {
highest++ highest++
err = svc.repo.Items.SetAssetID(ctx, GID, item.ID, repo.AssetID(highest)) err = svc.repo.Items.SetAssetID(ctx, GID, item.ID, highest)
if err != nil { if err != nil {
return 0, err return 0, err
} }

View file

@ -12,8 +12,8 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func (svc *ItemService) AttachmentPath(ctx context.Context, attachmentId uuid.UUID) (*ent.Document, error) { func (svc *ItemService) AttachmentPath(ctx context.Context, attachmentID uuid.UUID) (*ent.Document, error) {
attachment, err := svc.repo.Attachments.Get(ctx, attachmentId) attachment, err := svc.repo.Attachments.Get(ctx, attachmentID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -21,7 +21,7 @@ func (svc *ItemService) AttachmentPath(ctx context.Context, attachmentId uuid.UU
return attachment.Edges.Document, nil return attachment.Edges.Document, nil
} }
func (svc *ItemService) AttachmentUpdate(ctx Context, itemId uuid.UUID, data *repo.ItemAttachmentUpdate) (repo.ItemOut, error) { func (svc *ItemService) AttachmentUpdate(ctx Context, itemID uuid.UUID, data *repo.ItemAttachmentUpdate) (repo.ItemOut, error) {
// Update Attachment // Update Attachment
attachment, err := svc.repo.Attachments.Update(ctx, data.ID, data) attachment, err := svc.repo.Attachments.Update(ctx, data.ID, data)
if err != nil { if err != nil {
@ -35,15 +35,15 @@ func (svc *ItemService) AttachmentUpdate(ctx Context, itemId uuid.UUID, data *re
return repo.ItemOut{}, err return repo.ItemOut{}, err
} }
return svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemId) return svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemID)
} }
// AttachmentAdd adds an attachment to an item by creating an entry in the Documents table and linking it to the Attachment // AttachmentAdd adds an attachment to an item by creating an entry in the Documents table and linking it to the Attachment
// Table and Items table. The file provided via the reader is stored on the file system based on the provided // Table and Items table. The file provided via the reader is stored on the file system based on the provided
// relative path during construction of the service. // relative path during construction of the service.
func (svc *ItemService) AttachmentAdd(ctx Context, itemId uuid.UUID, filename string, attachmentType attachment.Type, file io.Reader) (repo.ItemOut, error) { func (svc *ItemService) AttachmentAdd(ctx Context, itemID uuid.UUID, filename string, attachmentType attachment.Type, file io.Reader) (repo.ItemOut, error) {
// Get the Item // Get the Item
_, err := svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemId) _, err := svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemID)
if err != nil { if err != nil {
return repo.ItemOut{}, err return repo.ItemOut{}, err
} }
@ -56,29 +56,29 @@ func (svc *ItemService) AttachmentAdd(ctx Context, itemId uuid.UUID, filename st
} }
// Create the attachment // Create the attachment
_, err = svc.repo.Attachments.Create(ctx, itemId, doc.ID, attachmentType) _, err = svc.repo.Attachments.Create(ctx, itemID, doc.ID, attachmentType)
if err != nil { if err != nil {
log.Err(err).Msg("failed to create attachment") log.Err(err).Msg("failed to create attachment")
return repo.ItemOut{}, err return repo.ItemOut{}, err
} }
return svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemId) return svc.repo.Items.GetOneByGroup(ctx, ctx.GID, itemID)
} }
func (svc *ItemService) AttachmentDelete(ctx context.Context, gid, itemId, attachmentId uuid.UUID) error { func (svc *ItemService) AttachmentDelete(ctx context.Context, gid, itemID, attachmentID uuid.UUID) error {
// Get the Item // Get the Item
_, err := svc.repo.Items.GetOneByGroup(ctx, gid, itemId) _, err := svc.repo.Items.GetOneByGroup(ctx, gid, itemID)
if err != nil { if err != nil {
return err return err
} }
attachment, err := svc.repo.Attachments.Get(ctx, attachmentId) attachment, err := svc.repo.Attachments.Get(ctx, attachmentID)
if err != nil { if err != nil {
return err return err
} }
// Delete the attachment // Delete the attachment
err = svc.repo.Attachments.Delete(ctx, attachmentId) err = svc.repo.Attachments.Delete(ctx, attachmentID)
if err != nil { if err != nil {
return err return err
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/data/repo" "github.com/hay-kot/homebox/backend/internal/data/repo"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestItemService_AddAttachment(t *testing.T) { func TestItemService_AddAttachment(t *testing.T) {
@ -23,7 +24,7 @@ func TestItemService_AddAttachment(t *testing.T) {
Description: "test", Description: "test",
Name: "test", Name: "test",
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, loc) assert.NotNil(t, loc)
itmC := repo.ItemCreate{ itmC := repo.ItemCreate{
@ -33,11 +34,11 @@ func TestItemService_AddAttachment(t *testing.T) {
} }
itm, err := svc.repo.Items.Create(context.Background(), tGroup.ID, itmC) itm, err := svc.repo.Items.Create(context.Background(), tGroup.ID, itmC)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, itm) assert.NotNil(t, itm)
t.Cleanup(func() { t.Cleanup(func() {
err := svc.repo.Items.Delete(context.Background(), itm.ID) err := svc.repo.Items.Delete(context.Background(), itm.ID)
assert.NoError(t, err) require.NoError(t, err)
}) })
contents := fk.Str(1000) contents := fk.Str(1000)
@ -45,7 +46,7 @@ func TestItemService_AddAttachment(t *testing.T) {
// Setup // Setup
afterAttachment, err := svc.AttachmentAdd(tCtx, itm.ID, "testfile.txt", "attachment", reader) afterAttachment, err := svc.AttachmentAdd(tCtx, itm.ID, "testfile.txt", "attachment", reader)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, afterAttachment) assert.NotNil(t, afterAttachment)
// Check that the file exists // Check that the file exists
@ -56,6 +57,6 @@ func TestItemService_AddAttachment(t *testing.T) {
// Check that the file contents are correct // Check that the file contents are correct
bts, err := os.ReadFile(storedPath) bts, err := os.ReadFile(storedPath)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, contents, string(bts)) assert.Equal(t, contents, string(bts))
} }

View file

@ -16,7 +16,7 @@ var (
oneWeek = time.Hour * 24 * 7 oneWeek = time.Hour * 24 * 7
ErrorInvalidLogin = errors.New("invalid username or password") ErrorInvalidLogin = errors.New("invalid username or password")
ErrorInvalidToken = errors.New("invalid token") ErrorInvalidToken = errors.New("invalid token")
ErrorTokenIdMismatch = errors.New("token id mismatch") ErrorTokenIDMismatch = errors.New("token id mismatch")
) )
type UserService struct { type UserService struct {
@ -134,13 +134,13 @@ func (svc *UserService) UpdateSelf(ctx context.Context, ID uuid.UUID, data repo.
return repo.UserOut{}, err return repo.UserOut{}, err
} }
return svc.repos.Users.GetOneId(ctx, ID) return svc.repos.Users.GetOneID(ctx, ID)
} }
// ============================================================================ // ============================================================================
// User Authentication // User Authentication
func (svc *UserService) createSessionToken(ctx context.Context, userId uuid.UUID, extendedSession bool) (UserAuthTokenDetail, error) { func (svc *UserService) createSessionToken(ctx context.Context, userID uuid.UUID, extendedSession bool) (UserAuthTokenDetail, error) {
attachmentToken := hasher.GenerateToken() attachmentToken := hasher.GenerateToken()
expiresAt := time.Now().Add(oneWeek) expiresAt := time.Now().Add(oneWeek)
@ -149,7 +149,7 @@ func (svc *UserService) createSessionToken(ctx context.Context, userId uuid.UUID
} }
attachmentData := repo.UserAuthTokenCreate{ attachmentData := repo.UserAuthTokenCreate{
UserID: userId, UserID: userID,
TokenHash: attachmentToken.Hash, TokenHash: attachmentToken.Hash,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
} }
@ -161,7 +161,7 @@ func (svc *UserService) createSessionToken(ctx context.Context, userId uuid.UUID
userToken := hasher.GenerateToken() userToken := hasher.GenerateToken()
data := repo.UserAuthTokenCreate{ data := repo.UserAuthTokenCreate{
UserID: userId, UserID: userID,
TokenHash: userToken.Hash, TokenHash: userToken.Hash,
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
} }
@ -218,7 +218,7 @@ func (svc *UserService) DeleteSelf(ctx context.Context, ID uuid.UUID) error {
} }
func (svc *UserService) ChangePassword(ctx Context, current string, new string) (ok bool) { func (svc *UserService) ChangePassword(ctx Context, current string, new string) (ok bool) {
usr, err := svc.repos.Users.GetOneId(ctx, ctx.UID) usr, err := svc.repos.Users.GetOneID(ctx, ctx.UID)
if err != nil { if err != nil {
return false return false
} }

View file

@ -1,3 +1,4 @@
// Package migrations provides a way to embed the migrations into the binary.
package migrations package migrations
import ( import (

View file

@ -54,7 +54,7 @@ func TestMain(m *testing.M) {
tClient = client tClient = client
tRepos = New(tClient, tbus, os.TempDir()) tRepos = New(tClient, tbus, os.TempDir())
defer client.Close() defer func() { _ = client.Close() }()
bootstrap() bootstrap()

View file

@ -11,6 +11,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/data/ent" "github.com/hay-kot/homebox/backend/internal/data/ent"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func useDocs(t *testing.T, num int) []DocumentOut { func useDocs(t *testing.T, num int) []DocumentOut {
@ -25,7 +26,7 @@ func useDocs(t *testing.T, num int) []DocumentOut {
Content: bytes.NewReader([]byte(fk.Str(10))), Content: bytes.NewReader([]byte(fk.Str(10))),
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, doc) assert.NotNil(t, doc)
results = append(results, doc) results = append(results, doc)
ids = append(ids, doc.ID) ids = append(ids, doc.ID)
@ -80,31 +81,31 @@ func TestDocumentRepository_CreateUpdateDelete(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create Document // Create Document
got, err := r.Create(tt.args.ctx, tt.args.gid, tt.args.doc) got, err := r.Create(tt.args.ctx, tt.args.gid, tt.args.doc)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.title, got.Title) assert.Equal(t, tt.title, got.Title)
assert.Equal(t, fmt.Sprintf("%s/%s/documents", temp, tt.args.gid), filepath.Dir(got.Path)) assert.Equal(t, fmt.Sprintf("%s/%s/documents", temp, tt.args.gid), filepath.Dir(got.Path))
ensureRead := func() { ensureRead := func() {
// Read Document // Read Document
bts, err := os.ReadFile(got.Path) bts, err := os.ReadFile(got.Path)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.content, string(bts)) assert.Equal(t, tt.content, string(bts))
} }
ensureRead() ensureRead()
// Update Document // Update Document
got, err = r.Rename(tt.args.ctx, got.ID, "__"+tt.title+"__") got, err = r.Rename(tt.args.ctx, got.ID, "__"+tt.title+"__")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "__"+tt.title+"__", got.Title) assert.Equal(t, "__"+tt.title+"__", got.Title)
ensureRead() ensureRead()
// Delete Document // Delete Document
err = r.Delete(tt.args.ctx, got.ID) err = r.Delete(tt.args.ctx, got.ID)
assert.NoError(t, err) require.NoError(t, err)
_, err = os.Stat(got.Path) _, err = os.Stat(got.Path)
assert.Error(t, err) require.Error(t, err)
}) })
} }
} }

View file

@ -5,29 +5,30 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_Group_Create(t *testing.T) { func Test_Group_Create(t *testing.T) {
g, err := tRepos.Groups.GroupCreate(context.Background(), "test") g, err := tRepos.Groups.GroupCreate(context.Background(), "test")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "test", g.Name) assert.Equal(t, "test", g.Name)
// Get by ID // Get by ID
foundGroup, err := tRepos.Groups.GroupByID(context.Background(), g.ID) foundGroup, err := tRepos.Groups.GroupByID(context.Background(), g.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, g.ID, foundGroup.ID) assert.Equal(t, g.ID, foundGroup.ID)
} }
func Test_Group_Update(t *testing.T) { func Test_Group_Update(t *testing.T) {
g, err := tRepos.Groups.GroupCreate(context.Background(), "test") g, err := tRepos.Groups.GroupCreate(context.Background(), "test")
assert.NoError(t, err) require.NoError(t, err)
g, err = tRepos.Groups.GroupUpdate(context.Background(), g.ID, GroupUpdate{ g, err = tRepos.Groups.GroupUpdate(context.Background(), g.ID, GroupUpdate{
Name: "test2", Name: "test2",
Currency: "eur", Currency: "eur",
}) })
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "test2", g.Name) assert.Equal(t, "test2", g.Name)
assert.Equal(t, "EUR", g.Currency) assert.Equal(t, "EUR", g.Currency)
} }
@ -38,7 +39,7 @@ func Test_Group_GroupStatistics(t *testing.T) {
stats, err := tRepos.Groups.StatsGroup(context.Background(), tGroup.ID) stats, err := tRepos.Groups.StatsGroup(context.Background(), tGroup.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 20, stats.TotalItems) assert.Equal(t, 20, stats.TotalItems)
assert.Equal(t, 20, stats.TotalLabels) assert.Equal(t, 20, stats.TotalLabels)
assert.Equal(t, 1, stats.TotalUsers) assert.Equal(t, 1, stats.TotalUsers)

View file

@ -51,18 +51,18 @@ func ToItemAttachment(attachment *ent.Attachment) ItemAttachment {
} }
} }
func (r *AttachmentRepo) Create(ctx context.Context, itemId, docId uuid.UUID, typ attachment.Type) (*ent.Attachment, error) { func (r *AttachmentRepo) Create(ctx context.Context, itemID, docID uuid.UUID, typ attachment.Type) (*ent.Attachment, error) {
bldr := r.db.Attachment.Create(). bldr := r.db.Attachment.Create().
SetType(typ). SetType(typ).
SetDocumentID(docId). SetDocumentID(docID).
SetItemID(itemId) SetItemID(itemID)
// Autoset primary to true if this is the first attachment // Autoset primary to true if this is the first attachment
// that is of type photo // that is of type photo
if typ == attachment.TypePhoto { if typ == attachment.TypePhoto {
cnt, err := r.db.Attachment.Query(). cnt, err := r.db.Attachment.Query().
Where( Where(
attachment.HasItemWith(item.ID(itemId)), attachment.HasItemWith(item.ID(itemID)),
attachment.TypeEQ(typ), attachment.TypeEQ(typ),
). ).
Count(ctx) Count(ctx)
@ -87,11 +87,11 @@ func (r *AttachmentRepo) Get(ctx context.Context, id uuid.UUID) (*ent.Attachment
Only(ctx) Only(ctx)
} }
func (r *AttachmentRepo) Update(ctx context.Context, itemId uuid.UUID, data *ItemAttachmentUpdate) (*ent.Attachment, error) { func (r *AttachmentRepo) Update(ctx context.Context, itemID uuid.UUID, data *ItemAttachmentUpdate) (*ent.Attachment, error) {
// TODO: execute within Tx // TODO: execute within Tx
typ := attachment.Type(data.Type) typ := attachment.Type(data.Type)
bldr := r.db.Attachment.UpdateOneID(itemId). bldr := r.db.Attachment.UpdateOneID(itemID).
SetType(typ) SetType(typ)
// Primary only applies to photos // Primary only applies to photos
@ -109,7 +109,7 @@ func (r *AttachmentRepo) Update(ctx context.Context, itemId uuid.UUID, data *Ite
// Ensure all other attachments are not primary // Ensure all other attachments are not primary
err = r.db.Attachment.Update(). err = r.db.Attachment.Update().
Where( Where(
attachment.HasItemWith(item.ID(itemId)), attachment.HasItemWith(item.ID(itemID)),
attachment.IDNEQ(itm.ID), attachment.IDNEQ(itm.ID),
). ).
SetPrimary(false). SetPrimary(false).

View file

@ -8,6 +8,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/data/ent" "github.com/hay-kot/homebox/backend/internal/data/ent"
"github.com/hay-kot/homebox/backend/internal/data/ent/attachment" "github.com/hay-kot/homebox/backend/internal/data/ent/attachment"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestAttachmentRepo_Create(t *testing.T) { func TestAttachmentRepo_Create(t *testing.T) {
@ -23,8 +24,8 @@ func TestAttachmentRepo_Create(t *testing.T) {
type args struct { type args struct {
ctx context.Context ctx context.Context
itemId uuid.UUID itemID uuid.UUID
docId uuid.UUID docID uuid.UUID
typ attachment.Type typ attachment.Type
} }
tests := []struct { tests := []struct {
@ -37,8 +38,8 @@ func TestAttachmentRepo_Create(t *testing.T) {
name: "create attachment", name: "create attachment",
args: args{ args: args{
ctx: context.Background(), ctx: context.Background(),
itemId: item.ID, itemID: item.ID,
docId: doc.ID, docID: doc.ID,
typ: attachment.TypePhoto, typ: attachment.TypePhoto,
}, },
want: &ent.Attachment{ want: &ent.Attachment{
@ -49,8 +50,8 @@ func TestAttachmentRepo_Create(t *testing.T) {
name: "create attachment with invalid item id", name: "create attachment with invalid item id",
args: args{ args: args{
ctx: context.Background(), ctx: context.Background(),
itemId: uuid.New(), itemID: uuid.New(),
docId: doc.ID, docID: doc.ID,
typ: "blarg", typ: "blarg",
}, },
wantErr: true, wantErr: true,
@ -58,7 +59,7 @@ func TestAttachmentRepo_Create(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := tRepos.Attachments.Create(tt.args.ctx, tt.args.itemId, tt.args.docId, tt.args.typ) got, err := tRepos.Attachments.Create(tt.args.ctx, tt.args.itemID, tt.args.docID, tt.args.typ)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("AttachmentRepo.Create() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("AttachmentRepo.Create() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -71,9 +72,9 @@ func TestAttachmentRepo_Create(t *testing.T) {
assert.Equal(t, tt.want.Type, got.Type) assert.Equal(t, tt.want.Type, got.Type)
withItems, err := tRepos.Attachments.Get(tt.args.ctx, got.ID) withItems, err := tRepos.Attachments.Get(tt.args.ctx, got.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, tt.args.itemId, withItems.Edges.Item.ID) assert.Equal(t, tt.args.itemID, withItems.Edges.Item.ID)
assert.Equal(t, tt.args.docId, withItems.Edges.Document.ID) assert.Equal(t, tt.args.docID, withItems.Edges.Document.ID)
ids = append(ids, got.ID) ids = append(ids, got.ID)
}) })
@ -96,7 +97,7 @@ func useAttachments(t *testing.T, n int) []*ent.Attachment {
attachments := make([]*ent.Attachment, n) attachments := make([]*ent.Attachment, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
attachment, err := tRepos.Attachments.Create(context.Background(), item.ID, doc.ID, attachment.TypePhoto) attachment, err := tRepos.Attachments.Create(context.Background(), item.ID, doc.ID, attachment.TypePhoto)
assert.NoError(t, err) require.NoError(t, err)
attachments[i] = attachment attachments[i] = attachment
ids = append(ids, attachment.ID) ids = append(ids, attachment.ID)
@ -114,10 +115,10 @@ func TestAttachmentRepo_Update(t *testing.T) {
Type: string(typ), Type: string(typ),
}) })
assert.NoError(t, err) require.NoError(t, err)
updated, err := tRepos.Attachments.Get(context.Background(), entity.ID) updated, err := tRepos.Attachments.Get(context.Background(), entity.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, typ, updated.Type) assert.Equal(t, typ, updated.Type)
}) })
} }
@ -127,8 +128,8 @@ func TestAttachmentRepo_Delete(t *testing.T) {
entity := useAttachments(t, 1)[0] entity := useAttachments(t, 1)[0]
err := tRepos.Attachments.Delete(context.Background(), entity.ID) err := tRepos.Attachments.Delete(context.Background(), entity.ID)
assert.NoError(t, err) require.NoError(t, err)
_, err = tRepos.Attachments.Get(context.Background(), entity.ID) _, err = tRepos.Attachments.Get(context.Background(), entity.ID)
assert.Error(t, err) require.Error(t, err)
} }

View file

@ -276,9 +276,9 @@ func mapItemOut(item *ent.Item) ItemOut {
} }
} }
func (r *ItemsRepository) publishMutationEvent(GID uuid.UUID) { func (e *ItemsRepository) publishMutationEvent(GID uuid.UUID) {
if r.bus != nil { if e.bus != nil {
r.bus.Publish(eventbus.EventItemMutation, eventbus.GroupMutationEvent{GID: GID}) e.bus.Publish(eventbus.EventItemMutation, eventbus.GroupMutationEvent{GID: GID})
} }
} }

View file

@ -22,7 +22,7 @@ func useItems(t *testing.T, len int) []ItemOut {
t.Helper() t.Helper()
location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory())
assert.NoError(t, err) require.NoError(t, err)
items := make([]ItemOut, len) items := make([]ItemOut, len)
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
@ -30,7 +30,7 @@ func useItems(t *testing.T, len int) []ItemOut {
itm.LocationID = location.ID itm.LocationID = location.ID
item, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm) item, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm)
assert.NoError(t, err) require.NoError(t, err)
items[i] = item items[i] = item
} }
@ -61,23 +61,22 @@ func TestItemsRepository_RecursiveRelationships(t *testing.T) {
// Append Parent ID // Append Parent ID
_, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, update) _, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, update)
assert.NoError(t, err) require.NoError(t, err)
// Check Parent ID // Check Parent ID
updated, err := tRepos.Items.GetOne(context.Background(), child.ID) updated, err := tRepos.Items.GetOne(context.Background(), child.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, parent.ID, updated.Parent.ID) assert.Equal(t, parent.ID, updated.Parent.ID)
// Remove Parent ID // Remove Parent ID
update.ParentID = uuid.Nil update.ParentID = uuid.Nil
_, err = tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, update) _, err = tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, update)
assert.NoError(t, err) require.NoError(t, err)
// Check Parent ID // Check Parent ID
updated, err = tRepos.Items.GetOne(context.Background(), child.ID) updated, err = tRepos.Items.GetOne(context.Background(), child.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Nil(t, updated.Parent) assert.Nil(t, updated.Parent)
} }
} }
@ -86,7 +85,7 @@ func TestItemsRepository_GetOne(t *testing.T) {
for _, item := range entity { for _, item := range entity {
result, err := tRepos.Items.GetOne(context.Background(), item.ID) result, err := tRepos.Items.GetOne(context.Background(), item.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, item.ID, result.ID) assert.Equal(t, item.ID, result.ID)
} }
} }
@ -96,9 +95,9 @@ func TestItemsRepository_GetAll(t *testing.T) {
expected := useItems(t, length) expected := useItems(t, length)
results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID) results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, length, len(results)) assert.Len(t, results, length)
for _, item := range results { for _, item := range results {
for _, expectedItem := range expected { for _, expectedItem := range expected {
@ -113,23 +112,23 @@ func TestItemsRepository_GetAll(t *testing.T) {
func TestItemsRepository_Create(t *testing.T) { func TestItemsRepository_Create(t *testing.T) {
location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory())
assert.NoError(t, err) require.NoError(t, err)
itm := itemFactory() itm := itemFactory()
itm.LocationID = location.ID itm.LocationID = location.ID
result, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm) result, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, result.ID) assert.NotEmpty(t, result.ID)
// Cleanup - Also deletes item // Cleanup - Also deletes item
err = tRepos.Locations.delete(context.Background(), location.ID) err = tRepos.Locations.delete(context.Background(), location.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestItemsRepository_Create_Location(t *testing.T) { func TestItemsRepository_Create_Location(t *testing.T) {
location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory())
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, location.ID) assert.NotEmpty(t, location.ID)
item := itemFactory() item := itemFactory()
@ -137,18 +136,18 @@ func TestItemsRepository_Create_Location(t *testing.T) {
// Create Resource // Create Resource
result, err := tRepos.Items.Create(context.Background(), tGroup.ID, item) result, err := tRepos.Items.Create(context.Background(), tGroup.ID, item)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEmpty(t, result.ID) assert.NotEmpty(t, result.ID)
// Get Resource // Get Resource
foundItem, err := tRepos.Items.GetOne(context.Background(), result.ID) foundItem, err := tRepos.Items.GetOne(context.Background(), result.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, result.ID, foundItem.ID) assert.Equal(t, result.ID, foundItem.ID)
assert.Equal(t, location.ID, foundItem.Location.ID) assert.Equal(t, location.ID, foundItem.Location.ID)
// Cleanup - Also deletes item // Cleanup - Also deletes item
err = tRepos.Locations.delete(context.Background(), location.ID) err = tRepos.Locations.delete(context.Background(), location.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestItemsRepository_Delete(t *testing.T) { func TestItemsRepository_Delete(t *testing.T) {
@ -156,11 +155,11 @@ func TestItemsRepository_Delete(t *testing.T) {
for _, item := range entities { for _, item := range entities {
err := tRepos.Items.Delete(context.Background(), item.ID) err := tRepos.Items.Delete(context.Background(), item.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID) results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, results) assert.Empty(t, results)
} }
@ -213,7 +212,7 @@ func TestItemsRepository_Update_Labels(t *testing.T) {
} }
updated, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, updateData) updated, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, updateData)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, tt.want, len(updated.Labels)) assert.Len(t, tt.want, len(updated.Labels))
for _, label := range updated.Labels { for _, label := range updated.Labels {
@ -250,10 +249,10 @@ func TestItemsRepository_Update(t *testing.T) {
} }
updatedEntity, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, updateData) updatedEntity, err := tRepos.Items.UpdateByGroup(context.Background(), tGroup.ID, updateData)
assert.NoError(t, err) require.NoError(t, err)
got, err := tRepos.Items.GetOne(context.Background(), updatedEntity.ID) got, err := tRepos.Items.GetOne(context.Background(), updatedEntity.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, updateData.ID, got.ID) assert.Equal(t, updateData.ID, got.ID)
assert.Equal(t, updateData.Name, got.Name) assert.Equal(t, updateData.Name, got.Name)
@ -263,10 +262,10 @@ func TestItemsRepository_Update(t *testing.T) {
assert.Equal(t, updateData.Manufacturer, got.Manufacturer) assert.Equal(t, updateData.Manufacturer, got.Manufacturer)
// assert.Equal(t, updateData.PurchaseTime, got.PurchaseTime) // assert.Equal(t, updateData.PurchaseTime, got.PurchaseTime)
assert.Equal(t, updateData.PurchaseFrom, got.PurchaseFrom) assert.Equal(t, updateData.PurchaseFrom, got.PurchaseFrom)
assert.Equal(t, updateData.PurchasePrice, got.PurchasePrice) assert.InDelta(t, updateData.PurchasePrice, got.PurchasePrice, 0.01)
// assert.Equal(t, updateData.SoldTime, got.SoldTime) // assert.Equal(t, updateData.SoldTime, got.SoldTime)
assert.Equal(t, updateData.SoldTo, got.SoldTo) assert.Equal(t, updateData.SoldTo, got.SoldTo)
assert.Equal(t, updateData.SoldPrice, got.SoldPrice) assert.InDelta(t, updateData.SoldPrice, got.SoldPrice, 0.01)
assert.Equal(t, updateData.SoldNotes, got.SoldNotes) assert.Equal(t, updateData.SoldNotes, got.SoldNotes)
assert.Equal(t, updateData.Notes, got.Notes) assert.Equal(t, updateData.Notes, got.Notes)
// assert.Equal(t, updateData.WarrantyExpires, got.WarrantyExpires) // assert.Equal(t, updateData.WarrantyExpires, got.WarrantyExpires)
@ -275,15 +274,15 @@ func TestItemsRepository_Update(t *testing.T) {
} }
func TestItemRepository_GetAllCustomFields(t *testing.T) { func TestItemRepository_GetAllCustomFields(t *testing.T) {
const FIELDS_COUNT = 5 const FieldsCount = 5
entity := useItems(t, 1)[0] entity := useItems(t, 1)[0]
fields := make([]ItemField, FIELDS_COUNT) fields := make([]ItemField, FieldsCount)
names := make([]string, FIELDS_COUNT) names := make([]string, FieldsCount)
values := make([]string, FIELDS_COUNT) values := make([]string, FieldsCount)
for i := 0; i < FIELDS_COUNT; i++ { for i := 0; i < FieldsCount; i++ {
name := fk.Str(10) name := fk.Str(10)
fields[i] = ItemField{ fields[i] = ItemField{
Name: name, Name: name,
@ -306,7 +305,7 @@ func TestItemRepository_GetAllCustomFields(t *testing.T) {
// Test getting all fields // Test getting all fields
{ {
results, err := tRepos.Items.GetAllCustomFieldNames(context.Background(), tGroup.ID) results, err := tRepos.Items.GetAllCustomFieldNames(context.Background(), tGroup.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, names, results) assert.ElementsMatch(t, names, results)
} }
@ -314,7 +313,7 @@ func TestItemRepository_GetAllCustomFields(t *testing.T) {
{ {
results, err := tRepos.Items.GetAllCustomFieldValues(context.Background(), tUser.GroupID, names[0]) results, err := tRepos.Items.GetAllCustomFieldValues(context.Background(), tUser.GroupID, names[0])
assert.NoError(t, err) require.NoError(t, err)
assert.ElementsMatch(t, values[:1], results) assert.ElementsMatch(t, values[:1], results)
} }
} }

View file

@ -87,28 +87,28 @@ func (r *LabelRepository) GetOneByGroup(ctx context.Context, gid, ld uuid.UUID)
return r.getOne(ctx, label.ID(ld), label.HasGroupWith(group.ID(gid))) return r.getOne(ctx, label.ID(ld), label.HasGroupWith(group.ID(gid)))
} }
func (r *LabelRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]LabelSummary, error) { func (r *LabelRepository) GetAll(ctx context.Context, groupID uuid.UUID) ([]LabelSummary, error) {
return mapLabelsOut(r.db.Label.Query(). return mapLabelsOut(r.db.Label.Query().
Where(label.HasGroupWith(group.ID(groupId))). Where(label.HasGroupWith(group.ID(groupID))).
Order(ent.Asc(label.FieldName)). Order(ent.Asc(label.FieldName)).
WithGroup(). WithGroup().
All(ctx), All(ctx),
) )
} }
func (r *LabelRepository) Create(ctx context.Context, groupdId uuid.UUID, data LabelCreate) (LabelOut, error) { func (r *LabelRepository) Create(ctx context.Context, groupID uuid.UUID, data LabelCreate) (LabelOut, error) {
label, err := r.db.Label.Create(). label, err := r.db.Label.Create().
SetName(data.Name). SetName(data.Name).
SetDescription(data.Description). SetDescription(data.Description).
SetColor(data.Color). SetColor(data.Color).
SetGroupID(groupdId). SetGroupID(groupID).
Save(ctx) Save(ctx)
if err != nil { if err != nil {
return LabelOut{}, err return LabelOut{}, err
} }
label.Edges.Group = &ent.Group{ID: groupdId} // bootstrap group ID label.Edges.Group = &ent.Group{ID: groupID} // bootstrap group ID
r.publishMutationEvent(groupdId) r.publishMutationEvent(groupID)
return mapLabelOut(label), err return mapLabelOut(label), err
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func labelFactory() LabelCreate { func labelFactory() LabelCreate {
@ -22,7 +23,7 @@ func useLabels(t *testing.T, len int) []LabelOut {
itm := labelFactory() itm := labelFactory()
item, err := tRepos.Labels.Create(context.Background(), tGroup.ID, itm) item, err := tRepos.Labels.Create(context.Background(), tGroup.ID, itm)
assert.NoError(t, err) require.NoError(t, err)
labels[i] = item labels[i] = item
} }
@ -41,7 +42,7 @@ func TestLabelRepository_Get(t *testing.T) {
// Get by ID // Get by ID
foundLoc, err := tRepos.Labels.GetOne(context.Background(), label.ID) foundLoc, err := tRepos.Labels.GetOne(context.Background(), label.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, label.ID, foundLoc.ID) assert.Equal(t, label.ID, foundLoc.ID)
} }
@ -49,26 +50,26 @@ func TestLabelRepositoryGetAll(t *testing.T) {
useLabels(t, 10) useLabels(t, 10)
all, err := tRepos.Labels.GetAll(context.Background(), tGroup.ID) all, err := tRepos.Labels.GetAll(context.Background(), tGroup.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, all, 10) assert.Len(t, all, 10)
} }
func TestLabelRepository_Create(t *testing.T) { func TestLabelRepository_Create(t *testing.T) {
loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory())
assert.NoError(t, err) require.NoError(t, err)
// Get by ID // Get by ID
foundLoc, err := tRepos.Labels.GetOne(context.Background(), loc.ID) foundLoc, err := tRepos.Labels.GetOne(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, loc.ID, foundLoc.ID) assert.Equal(t, loc.ID, foundLoc.ID)
err = tRepos.Labels.delete(context.Background(), loc.ID) err = tRepos.Labels.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLabelRepository_Update(t *testing.T) { func TestLabelRepository_Update(t *testing.T) {
loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory())
assert.NoError(t, err) require.NoError(t, err)
updateData := LabelUpdate{ updateData := LabelUpdate{
ID: loc.ID, ID: loc.ID,
@ -77,26 +78,26 @@ func TestLabelRepository_Update(t *testing.T) {
} }
update, err := tRepos.Labels.UpdateByGroup(context.Background(), tGroup.ID, updateData) update, err := tRepos.Labels.UpdateByGroup(context.Background(), tGroup.ID, updateData)
assert.NoError(t, err) require.NoError(t, err)
foundLoc, err := tRepos.Labels.GetOne(context.Background(), loc.ID) foundLoc, err := tRepos.Labels.GetOne(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, update.ID, foundLoc.ID) assert.Equal(t, update.ID, foundLoc.ID)
assert.Equal(t, update.Name, foundLoc.Name) assert.Equal(t, update.Name, foundLoc.Name)
assert.Equal(t, update.Description, foundLoc.Description) assert.Equal(t, update.Description, foundLoc.Description)
err = tRepos.Labels.delete(context.Background(), loc.ID) err = tRepos.Labels.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLabelRepository_Delete(t *testing.T) { func TestLabelRepository_Delete(t *testing.T) {
loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory())
assert.NoError(t, err) require.NoError(t, err)
err = tRepos.Labels.delete(context.Background(), loc.ID) err = tRepos.Labels.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
_, err = tRepos.Labels.GetOne(context.Background(), loc.ID) _, err = tRepos.Labels.GetOne(context.Background(), loc.ID)
assert.Error(t, err) require.Error(t, err)
} }

View file

@ -99,7 +99,7 @@ type LocationQuery struct {
FilterChildren bool `json:"filterChildren" schema:"filterChildren"` FilterChildren bool `json:"filterChildren" schema:"filterChildren"`
} }
// GetALlWithCount returns all locations with item count field populated // GetAll returns all locations with item count field populated
func (r *LocationRepository) GetAll(ctx context.Context, GID uuid.UUID, filter LocationQuery) ([]LocationOutCount, error) { func (r *LocationRepository) GetAll(ctx context.Context, GID uuid.UUID, filter LocationQuery) ([]LocationOutCount, error) {
query := `--sql query := `--sql
SELECT SELECT
@ -135,7 +135,7 @@ func (r *LocationRepository) GetAll(ctx context.Context, GID uuid.UUID, filter L
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer func() { _ = rows.Close() }()
list := []LocationOutCount{} list := []LocationOutCount{}
for rows.Next() { for rows.Next() {
@ -265,7 +265,7 @@ type LocationPath struct {
Name string `json:"name"` Name string `json:"name"`
} }
func (lr *LocationRepository) PathForLoc(ctx context.Context, GID, locID uuid.UUID) ([]LocationPath, error) { func (r *LocationRepository) PathForLoc(ctx context.Context, GID, locID uuid.UUID) ([]LocationPath, error) {
query := `WITH RECURSIVE location_path AS ( query := `WITH RECURSIVE location_path AS (
SELECT id, name, location_children SELECT id, name, location_children
FROM locations FROM locations
@ -282,11 +282,11 @@ func (lr *LocationRepository) PathForLoc(ctx context.Context, GID, locID uuid.UU
SELECT id, name SELECT id, name
FROM location_path` FROM location_path`
rows, err := lr.db.Sql().QueryContext(ctx, query, locID, GID) rows, err := r.db.Sql().QueryContext(ctx, query, locID, GID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var locations []LocationPath var locations []LocationPath
@ -311,7 +311,7 @@ func (lr *LocationRepository) PathForLoc(ctx context.Context, GID, locID uuid.UU
return locations, nil return locations, nil
} }
func (lr *LocationRepository) Tree(ctx context.Context, GID uuid.UUID, tq TreeQuery) ([]TreeItem, error) { func (r *LocationRepository) Tree(ctx context.Context, GID uuid.UUID, tq TreeQuery) ([]TreeItem, error) {
query := ` query := `
WITH recursive location_tree(id, NAME, parent_id, level, node_type) AS WITH recursive location_tree(id, NAME, parent_id, level, node_type) AS
( (
@ -393,11 +393,11 @@ func (lr *LocationRepository) Tree(ctx context.Context, GID uuid.UUID, tq TreeQu
query = strings.ReplaceAll(query, "{{ WITH_ITEMS_FROM }}", "") query = strings.ReplaceAll(query, "{{ WITH_ITEMS_FROM }}", "")
} }
rows, err := lr.db.Sql().QueryContext(ctx, query, GID) rows, err := r.db.Sql().QueryContext(ctx, query, GID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var locations []FlatTreeItem var locations []FlatTreeItem
for rows.Next() { for rows.Next() {

View file

@ -8,6 +8,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/data/ent" "github.com/hay-kot/homebox/backend/internal/data/ent"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func locationFactory() LocationCreate { func locationFactory() LocationCreate {
@ -24,7 +25,7 @@ func useLocations(t *testing.T, len int) []LocationOut {
for i := 0; i < len; i++ { for i := 0; i < len; i++ {
loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory())
assert.NoError(t, err) require.NoError(t, err)
out[i] = loc out[i] = loc
} }
@ -42,15 +43,15 @@ func useLocations(t *testing.T, len int) []LocationOut {
func TestLocationRepository_Get(t *testing.T) { func TestLocationRepository_Get(t *testing.T) {
loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory())
assert.NoError(t, err) require.NoError(t, err)
// Get by ID // Get by ID
foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, loc.ID, foundLoc.ID) assert.Equal(t, loc.ID, foundLoc.ID)
err = tRepos.Locations.delete(context.Background(), loc.ID) err = tRepos.Locations.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLocationRepositoryGetAllWithCount(t *testing.T) { func TestLocationRepositoryGetAllWithCount(t *testing.T) {
@ -63,10 +64,10 @@ func TestLocationRepositoryGetAllWithCount(t *testing.T) {
LocationID: result.ID, LocationID: result.ID,
}) })
assert.NoError(t, err) require.NoError(t, err)
results, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID, LocationQuery{}) results, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID, LocationQuery{})
assert.NoError(t, err) require.NoError(t, err)
for _, loc := range results { for _, loc := range results {
if loc.ID == result.ID { if loc.ID == result.ID {
@ -80,11 +81,11 @@ func TestLocationRepository_Create(t *testing.T) {
// Get by ID // Get by ID
foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, loc.ID, foundLoc.ID) assert.Equal(t, loc.ID, foundLoc.ID)
err = tRepos.Locations.delete(context.Background(), loc.ID) err = tRepos.Locations.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLocationRepository_Update(t *testing.T) { func TestLocationRepository_Update(t *testing.T) {
@ -97,27 +98,27 @@ func TestLocationRepository_Update(t *testing.T) {
} }
update, err := tRepos.Locations.UpdateByGroup(context.Background(), tGroup.ID, updateData.ID, updateData) update, err := tRepos.Locations.UpdateByGroup(context.Background(), tGroup.ID, updateData.ID, updateData)
assert.NoError(t, err) require.NoError(t, err)
foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, update.ID, foundLoc.ID) assert.Equal(t, update.ID, foundLoc.ID)
assert.Equal(t, update.Name, foundLoc.Name) assert.Equal(t, update.Name, foundLoc.Name)
assert.Equal(t, update.Description, foundLoc.Description) assert.Equal(t, update.Description, foundLoc.Description)
err = tRepos.Locations.delete(context.Background(), loc.ID) err = tRepos.Locations.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestLocationRepository_Delete(t *testing.T) { func TestLocationRepository_Delete(t *testing.T) {
loc := useLocations(t, 1)[0] loc := useLocations(t, 1)[0]
err := tRepos.Locations.delete(context.Background(), loc.ID) err := tRepos.Locations.delete(context.Background(), loc.ID)
assert.NoError(t, err) require.NoError(t, err)
_, err = tRepos.Locations.Get(context.Background(), loc.ID) _, err = tRepos.Locations.Get(context.Background(), loc.ID)
assert.Error(t, err) require.Error(t, err)
} }
func TestItemRepository_TreeQuery(t *testing.T) { func TestItemRepository_TreeQuery(t *testing.T) {
@ -130,18 +131,18 @@ func TestItemRepository_TreeQuery(t *testing.T) {
Name: locs[0].Name, Name: locs[0].Name,
Description: locs[0].Description, Description: locs[0].Description,
}) })
assert.NoError(t, err) require.NoError(t, err)
locations, err := tRepos.Locations.Tree(context.Background(), tGroup.ID, TreeQuery{WithItems: true}) locations, err := tRepos.Locations.Tree(context.Background(), tGroup.ID, TreeQuery{WithItems: true})
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 2, len(locations)) assert.Len(t, locations, 2)
// Check roots // Check roots
for _, loc := range locations { for _, loc := range locations {
if loc.ID == locs[1].ID { if loc.ID == locs[1].ID {
assert.Equal(t, 1, len(loc.Children)) assert.Len(t, loc.Children, 1)
} }
} }
} }
@ -157,15 +158,15 @@ func TestLocationRepository_PathForLoc(t *testing.T) {
Name: locs[i].Name, Name: locs[i].Name,
Description: locs[i].Description, Description: locs[i].Description,
}) })
assert.NoError(t, err) require.NoError(t, err)
} }
last := locs[0] last := locs[0]
path, err := tRepos.Locations.PathForLoc(context.Background(), tGroup.ID, last.ID) path, err := tRepos.Locations.PathForLoc(context.Background(), tGroup.ID, last.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, 3, len(path)) assert.Len(t, path, 3)
// Check path and order // Check path and order
for i, loc := range path { for i, loc := range path {

View file

@ -152,7 +152,6 @@ func (r *MaintenanceEntryRepository) GetLog(ctx context.Context, groupID, itemID
maintenanceentry.DateNotNil(), maintenanceentry.DateNotNil(),
maintenanceentry.DateNEQ(time.Time{}), maintenanceentry.DateNEQ(time.Time{}),
)) ))
} else if query.Scheduled { } else if query.Scheduled {
q = q.Where(maintenanceentry.And( q = q.Where(maintenanceentry.And(
maintenanceentry.Or( maintenanceentry.Or(

View file

@ -7,6 +7,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/data/types" "github.com/hay-kot/homebox/backend/internal/data/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
// get the previous month from the current month, accounts for errors when run // get the previous month from the current month, accounts for errors when run
@ -67,7 +68,7 @@ func TestMaintenanceEntryRepository_GetLog(t *testing.T) {
} }
assert.Equal(t, item.ID, log.ItemID) assert.Equal(t, item.ID, log.ItemID)
assert.Equal(t, 10, len(log.Entries)) assert.Len(t, log.Entries, 10)
// Calculate the average cost // Calculate the average cost
var total float64 var total float64
@ -76,11 +77,11 @@ func TestMaintenanceEntryRepository_GetLog(t *testing.T) {
total += entry.Cost total += entry.Cost
} }
assert.Equal(t, total, log.CostTotal, "total cost should be equal to the sum of all entries") assert.InDelta(t, total, log.CostTotal, .001, "total cost should be equal to the sum of all entries")
assert.Equal(t, total/2, log.CostAverage, "average cost should be the average of the two months") assert.InDelta(t, total/2, log.CostAverage, 001, "average cost should be the average of the two months")
for _, entry := range log.Entries { for _, entry := range log.Entries {
err := tRepos.MaintEntry.Delete(context.Background(), entry.ID) err := tRepos.MaintEntry.Delete(context.Background(), entry.ID)
assert.NoError(t, err) require.NoError(t, err)
} }
} }

View file

@ -43,7 +43,7 @@ type (
NotifierUpdate struct { NotifierUpdate struct {
Name string `json:"name" validate:"required,min=1,max=255"` Name string `json:"name" validate:"required,min=1,max=255"`
IsActive bool `json:"isActive"` IsActive bool `json:"isActive"`
URL *string `json:"url" validate:"omitempty,shoutrrr" extensions:"x-nullable" ` URL *string `json:"url" validate:"omitempty,shoutrrr" extensions:"x-nullable"`
} }
NotifierOut struct { NotifierOut struct {

View file

@ -71,7 +71,7 @@ func (r *TokenRepository) GetRoles(ctx context.Context, token string) (*set.Set[
return &roleSet, nil return &roleSet, nil
} }
// Creates a token for a user // CreateToken Creates a token for a user
func (r *TokenRepository) CreateToken(ctx context.Context, createToken UserAuthTokenCreate, roles ...authroles.Role) (UserAuthToken, error) { func (r *TokenRepository) CreateToken(ctx context.Context, createToken UserAuthTokenCreate, roles ...authroles.Role) (UserAuthToken, error) {
dbToken, err := r.db.AuthTokens.Create(). dbToken, err := r.db.AuthTokens.Create().
SetToken(createToken.TokenHash). SetToken(createToken.TokenHash).

View file

@ -7,15 +7,15 @@ import (
"github.com/hay-kot/homebox/backend/pkgs/hasher" "github.com/hay-kot/homebox/backend/pkgs/hasher"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestAuthTokenRepo_CreateToken(t *testing.T) { func TestAuthTokenRepo_CreateToken(t *testing.T) {
asrt := assert.New(t)
ctx := context.Background() ctx := context.Background()
user := userFactory() user := userFactory()
userOut, err := tRepos.Users.Create(ctx, user) userOut, err := tRepos.Users.Create(ctx, user)
asrt.NoError(err) require.NoError(t, err)
expiresAt := time.Now().Add(time.Hour) expiresAt := time.Now().Add(time.Hour)
@ -27,23 +27,22 @@ func TestAuthTokenRepo_CreateToken(t *testing.T) {
UserID: userOut.ID, UserID: userOut.ID,
}) })
asrt.NoError(err) require.NoError(t, err)
asrt.Equal(userOut.ID, token.UserID) assert.Equal(t, userOut.ID, token.UserID)
asrt.Equal(expiresAt, token.ExpiresAt) assert.Equal(t, expiresAt, token.ExpiresAt)
// Cleanup // Cleanup
asrt.NoError(tRepos.Users.Delete(ctx, userOut.ID)) require.NoError(t, tRepos.Users.Delete(ctx, userOut.ID))
_, err = tRepos.AuthTokens.DeleteAll(ctx) _, err = tRepos.AuthTokens.DeleteAll(ctx)
asrt.NoError(err) require.NoError(t, err)
} }
func TestAuthTokenRepo_DeleteToken(t *testing.T) { func TestAuthTokenRepo_DeleteToken(t *testing.T) {
asrt := assert.New(t)
ctx := context.Background() ctx := context.Background()
user := userFactory() user := userFactory()
userOut, err := tRepos.Users.Create(ctx, user) userOut, err := tRepos.Users.Create(ctx, user)
asrt.NoError(err) require.NoError(t, err)
expiresAt := time.Now().Add(time.Hour) expiresAt := time.Now().Add(time.Hour)
@ -54,15 +53,14 @@ func TestAuthTokenRepo_DeleteToken(t *testing.T) {
ExpiresAt: expiresAt, ExpiresAt: expiresAt,
UserID: userOut.ID, UserID: userOut.ID,
}) })
asrt.NoError(err) require.NoError(t, err)
// Delete token // Delete token
err = tRepos.AuthTokens.DeleteToken(ctx, []byte(generatedToken.Raw)) err = tRepos.AuthTokens.DeleteToken(ctx, []byte(generatedToken.Raw))
asrt.NoError(err) require.NoError(t, err)
} }
func TestAuthTokenRepo_GetUserByToken(t *testing.T) { func TestAuthTokenRepo_GetUserByToken(t *testing.T) {
assert := assert.New(t)
ctx := context.Background() ctx := context.Background()
user := userFactory() user := userFactory()
@ -77,24 +75,23 @@ func TestAuthTokenRepo_GetUserByToken(t *testing.T) {
UserID: userOut.ID, UserID: userOut.ID,
}) })
assert.NoError(err) require.NoError(t, err)
// Get User from token // Get User from token
foundUser, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) foundUser, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash)
assert.NoError(err) require.NoError(t, err)
assert.Equal(userOut.ID, foundUser.ID) assert.Equal(t, userOut.ID, foundUser.ID)
assert.Equal(userOut.Name, foundUser.Name) assert.Equal(t, userOut.Name, foundUser.Name)
assert.Equal(userOut.Email, foundUser.Email) assert.Equal(t, userOut.Email, foundUser.Email)
// Cleanup // Cleanup
assert.NoError(tRepos.Users.Delete(ctx, userOut.ID)) require.NoError(t, tRepos.Users.Delete(ctx, userOut.ID))
_, err = tRepos.AuthTokens.DeleteAll(ctx) _, err = tRepos.AuthTokens.DeleteAll(ctx)
assert.NoError(err) require.NoError(t, err)
} }
func TestAuthTokenRepo_PurgeExpiredTokens(t *testing.T) { func TestAuthTokenRepo_PurgeExpiredTokens(t *testing.T) {
assert := assert.New(t)
ctx := context.Background() ctx := context.Background()
user := userFactory() user := userFactory()
@ -112,27 +109,26 @@ func TestAuthTokenRepo_PurgeExpiredTokens(t *testing.T) {
UserID: userOut.ID, UserID: userOut.ID,
}) })
assert.NoError(err) require.NoError(t, err)
assert.NotNil(createdToken) assert.NotNil(t, createdToken)
createdTokens = append(createdTokens, createdToken) createdTokens = append(createdTokens, createdToken)
} }
// Purge expired tokens // Purge expired tokens
tokensDeleted, err := tRepos.AuthTokens.PurgeExpiredTokens(ctx) tokensDeleted, err := tRepos.AuthTokens.PurgeExpiredTokens(ctx)
assert.NoError(err) require.NoError(t, err)
assert.Equal(5, tokensDeleted) assert.Equal(t, 5, tokensDeleted)
// Check if tokens are deleted // Check if tokens are deleted
for _, token := range createdTokens { for _, token := range createdTokens {
_, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) _, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash)
assert.Error(err) require.Error(t, err)
} }
// Cleanup // Cleanup
assert.NoError(tRepos.Users.Delete(ctx, userOut.ID)) require.NoError(t, tRepos.Users.Delete(ctx, userOut.ID))
_, err = tRepos.AuthTokens.DeleteAll(ctx) _, err = tRepos.AuthTokens.DeleteAll(ctx)
assert.NoError(err) require.NoError(t, err)
} }

View file

@ -60,32 +60,32 @@ func mapUserOut(user *ent.User) UserOut {
} }
} }
func (e *UserRepository) GetOneId(ctx context.Context, id uuid.UUID) (UserOut, error) { func (r *UserRepository) GetOneID(ctx context.Context, ID uuid.UUID) (UserOut, error) {
return mapUserOutErr(e.db.User.Query(). return mapUserOutErr(r.db.User.Query().
Where(user.ID(id)). Where(user.ID(ID)).
WithGroup(). WithGroup().
Only(ctx)) Only(ctx))
} }
func (e *UserRepository) GetOneEmail(ctx context.Context, email string) (UserOut, error) { func (r *UserRepository) GetOneEmail(ctx context.Context, email string) (UserOut, error) {
return mapUserOutErr(e.db.User.Query(). return mapUserOutErr(r.db.User.Query().
Where(user.EmailEqualFold(email)). Where(user.EmailEqualFold(email)).
WithGroup(). WithGroup().
Only(ctx), Only(ctx),
) )
} }
func (e *UserRepository) GetAll(ctx context.Context) ([]UserOut, error) { func (r *UserRepository) GetAll(ctx context.Context) ([]UserOut, error) {
return mapUsersOutErr(e.db.User.Query().WithGroup().All(ctx)) return mapUsersOutErr(r.db.User.Query().WithGroup().All(ctx))
} }
func (e *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, error) { func (r *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, error) {
role := user.RoleUser role := user.RoleUser
if usr.IsOwner { if usr.IsOwner {
role = user.RoleOwner role = user.RoleOwner
} }
entUser, err := e.db.User. entUser, err := r.db.User.
Create(). Create().
SetName(usr.Name). SetName(usr.Name).
SetEmail(usr.Email). SetEmail(usr.Email).
@ -98,11 +98,11 @@ func (e *UserRepository) Create(ctx context.Context, usr UserCreate) (UserOut, e
return UserOut{}, err return UserOut{}, err
} }
return e.GetOneId(ctx, entUser.ID) return r.GetOneID(ctx, entUser.ID)
} }
func (e *UserRepository) Update(ctx context.Context, ID uuid.UUID, data UserUpdate) error { func (r *UserRepository) Update(ctx context.Context, ID uuid.UUID, data UserUpdate) error {
q := e.db.User.Update(). q := r.db.User.Update().
Where(user.ID(ID)). Where(user.ID(ID)).
SetName(data.Name). SetName(data.Name).
SetEmail(data.Email) SetEmail(data.Email)
@ -111,18 +111,18 @@ func (e *UserRepository) Update(ctx context.Context, ID uuid.UUID, data UserUpda
return err return err
} }
func (e *UserRepository) Delete(ctx context.Context, id uuid.UUID) error { func (r *UserRepository) Delete(ctx context.Context, id uuid.UUID) error {
_, err := e.db.User.Delete().Where(user.ID(id)).Exec(ctx) _, err := r.db.User.Delete().Where(user.ID(id)).Exec(ctx)
return err return err
} }
func (e *UserRepository) DeleteAll(ctx context.Context) error { func (r *UserRepository) DeleteAll(ctx context.Context) error {
_, err := e.db.User.Delete().Exec(ctx) _, err := r.db.User.Delete().Exec(ctx)
return err return err
} }
func (e *UserRepository) GetSuperusers(ctx context.Context) ([]*ent.User, error) { func (r *UserRepository) GetSuperusers(ctx context.Context) ([]*ent.User, error) {
users, err := e.db.User.Query().Where(user.IsSuperuser(true)).All(ctx) users, err := r.db.User.Query().Where(user.IsSuperuser(true)).All(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func userFactory() UserCreate { func userFactory() UserCreate {
@ -23,18 +24,18 @@ func TestUserRepo_GetOneEmail(t *testing.T) {
ctx := context.Background() ctx := context.Background()
_, err := tRepos.Users.Create(ctx, user) _, err := tRepos.Users.Create(ctx, user)
assert.NoError(err) require.NoError(t, err)
foundUser, err := tRepos.Users.GetOneEmail(ctx, user.Email) foundUser, err := tRepos.Users.GetOneEmail(ctx, user.Email)
assert.NotNil(foundUser) assert.NotNil(foundUser)
assert.Nil(err) require.NoError(t, err)
assert.Equal(user.Email, foundUser.Email) assert.Equal(user.Email, foundUser.Email)
assert.Equal(user.Name, foundUser.Name) assert.Equal(user.Name, foundUser.Name)
// Cleanup // Cleanup
err = tRepos.Users.DeleteAll(ctx) err = tRepos.Users.DeleteAll(ctx)
assert.NoError(err) require.NoError(t, err)
} }
func TestUserRepo_GetOneId(t *testing.T) { func TestUserRepo_GetOneId(t *testing.T) {
@ -43,16 +44,16 @@ func TestUserRepo_GetOneId(t *testing.T) {
ctx := context.Background() ctx := context.Background()
userOut, _ := tRepos.Users.Create(ctx, user) userOut, _ := tRepos.Users.Create(ctx, user)
foundUser, err := tRepos.Users.GetOneId(ctx, userOut.ID) foundUser, err := tRepos.Users.GetOneID(ctx, userOut.ID)
assert.NotNil(foundUser) assert.NotNil(foundUser)
assert.Nil(err) require.NoError(t, err)
assert.Equal(user.Email, foundUser.Email) assert.Equal(user.Email, foundUser.Email)
assert.Equal(user.Name, foundUser.Name) assert.Equal(user.Name, foundUser.Name)
// Cleanup // Cleanup
err = tRepos.Users.DeleteAll(ctx) err = tRepos.Users.DeleteAll(ctx)
assert.NoError(err) require.NoError(t, err)
} }
func TestUserRepo_GetAll(t *testing.T) { func TestUserRepo_GetAll(t *testing.T) {
@ -76,7 +77,7 @@ func TestUserRepo_GetAll(t *testing.T) {
// Validate // Validate
allUsers, err := tRepos.Users.GetAll(ctx) allUsers, err := tRepos.Users.GetAll(ctx)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(created), len(allUsers)) assert.Equal(t, len(created), len(allUsers))
for _, usr := range created { for _, usr := range created {
@ -96,12 +97,12 @@ func TestUserRepo_GetAll(t *testing.T) {
// Cleanup // Cleanup
err = tRepos.Users.DeleteAll(ctx) err = tRepos.Users.DeleteAll(ctx)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestUserRepo_Update(t *testing.T) { func TestUserRepo_Update(t *testing.T) {
user, err := tRepos.Users.Create(context.Background(), userFactory()) user, err := tRepos.Users.Create(context.Background(), userFactory())
assert.NoError(t, err) require.NoError(t, err)
updateData := UserUpdate{ updateData := UserUpdate{
Name: fk.Str(10), Name: fk.Str(10),
@ -110,11 +111,11 @@ func TestUserRepo_Update(t *testing.T) {
// Update // Update
err = tRepos.Users.Update(context.Background(), user.ID, updateData) err = tRepos.Users.Update(context.Background(), user.ID, updateData)
assert.NoError(t, err) require.NoError(t, err)
// Validate // Validate
updated, err := tRepos.Users.GetOneId(context.Background(), user.ID) updated, err := tRepos.Users.GetOneID(context.Background(), user.ID)
assert.NoError(t, err) require.NoError(t, err)
assert.NotEqual(t, user.Name, updated.Name) assert.NotEqual(t, user.Name, updated.Name)
assert.NotEqual(t, user.Email, updated.Email) assert.NotEqual(t, user.Email, updated.Email)
} }
@ -131,12 +132,12 @@ func TestUserRepo_Delete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
allUsers, _ := tRepos.Users.GetAll(ctx) allUsers, _ := tRepos.Users.GetAll(ctx)
assert.Greater(t, len(allUsers), 0) assert.NotEmpty(t, allUsers)
err := tRepos.Users.DeleteAll(ctx) err := tRepos.Users.DeleteAll(ctx)
assert.NoError(t, err) require.NoError(t, err)
allUsers, _ = tRepos.Users.GetAll(ctx) allUsers, _ = tRepos.Users.GetAll(ctx)
assert.Equal(t, len(allUsers), 0) assert.Empty(t, allUsers)
} }
func TestUserRepo_GetSuperusers(t *testing.T) { func TestUserRepo_GetSuperusers(t *testing.T) {
@ -160,7 +161,7 @@ func TestUserRepo_GetSuperusers(t *testing.T) {
ctx := context.Background() ctx := context.Background()
superUsers, err := tRepos.Users.GetSuperusers(ctx) superUsers, err := tRepos.Users.GetSuperusers(ctx)
assert.NoError(t, err) require.NoError(t, err)
for _, usr := range superUsers { for _, usr := range superUsers {
assert.True(t, usr.IsSuperuser) assert.True(t, usr.IsSuperuser)
@ -168,5 +169,5 @@ func TestUserRepo_GetSuperusers(t *testing.T) {
// Cleanup // Cleanup
err = tRepos.Users.DeleteAll(ctx) err = tRepos.Users.DeleteAll(ctx)
assert.NoError(t, err) require.NoError(t, err)
} }

View file

@ -1,3 +1,4 @@
// Package repo provides the data access layer for the application.
package repo package repo
import ( import (

View file

@ -1,3 +1,4 @@
// Package types provides custom types for the application.
package types package types
import ( import (

View file

@ -1,3 +1,4 @@
// Package config provides the configuration for the application.
package config package config
import ( import (

View file

@ -7,5 +7,5 @@ const (
type Storage struct { type Storage struct {
// Data is the path to the root directory // Data is the path to the root directory
Data string `yaml:"data" conf:"default:./.data"` Data string `yaml:"data" conf:"default:./.data"`
SqliteUrl string `yaml:"sqlite-url" conf:"default:./.data/homebox.db?_pragma=busy_timeout=1000&_pragma=journal_mode=WAL&_fk=1"` SqliteURL string `yaml:"sqlite-url" conf:"default:./.data/homebox.db?_pragma=busy_timeout=999&_pragma=journal_mode=WAL&_fk=1"`
} }

View file

@ -1,3 +1,4 @@
// Package validate provides a wrapper around the go-playground/validator package
package validate package validate
import ( import (
@ -8,7 +9,7 @@ import (
var validate *validator.Validate var validate *validator.Validate
func init() { func init() { // nolint
validate = validator.New() validate = validator.New()
err := validate.RegisterValidation("shoutrrr", func(fl validator.FieldLevel) bool { err := validate.RegisterValidation("shoutrrr", func(fl validator.FieldLevel) bool {
@ -52,17 +53,16 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
// Checks a struct for validation errors and returns any errors the occur. This // Check a struct for validation errors and returns any errors the occur. This
// wraps the validate.Struct() function and provides some error wrapping. When // wraps the validate.Struct() function and provides some error wrapping. When
// a validator.ValidationErrors is returned, it is wrapped transformed into a // a validator.ValidationErrors is returned, it is wrapped transformed into a
// FieldErrors array and returned. // FieldErrors array and returned.
func Check(val any) error { func Check(val any) error {
err := validate.Struct(val) err := validate.Struct(val)
if err != nil { if err != nil {
verrors, ok := err.(validator.ValidationErrors) verrors, ok := err.(validator.ValidationErrors) // nolint - we know it's a validator.ValidationErrors
if !ok { if !ok {
return err return err
} }

View file

@ -0,0 +1,2 @@
// Package mid provides web middleware.
package mid

View file

@ -44,7 +44,7 @@ func Errors(svr *server.Server, log zerolog.Logger) errchain.ErrorHandler {
case validate.IsFieldError(err): case validate.IsFieldError(err):
code = http.StatusUnprocessableEntity code = http.StatusUnprocessableEntity
fieldErrors := err.(validate.FieldErrors) fieldErrors := err.(validate.FieldErrors) // nolint
resp.Error = "Validation Error" resp.Error = "Validation Error"
resp.Fields = map[string]string{} resp.Fields = map[string]string{}
@ -52,7 +52,7 @@ func Errors(svr *server.Server, log zerolog.Logger) errchain.ErrorHandler {
resp.Fields[fieldError.Field] = fieldError.Error resp.Fields[fieldError.Field] = fieldError.Error
} }
case validate.IsRequestError(err): case validate.IsRequestError(err):
requestError := err.(*validate.RequestError) requestError := err.(*validate.RequestError) // nolint
resp.Error = requestError.Error() resp.Error = requestError.Error()
if requestError.Status == 0 { if requestError.Status == 0 {

View file

@ -1,4 +1,4 @@
// sqlite package provides a CGO free implementation of the sqlite3 driver. This wraps the // Package cgofreesqlite package provides a CGO free implementation of the sqlite3 driver. This wraps the
// modernc.org/sqlite driver and adds the PRAGMA foreign_keys = ON; statement to the connection // modernc.org/sqlite driver and adds the PRAGMA foreign_keys = ON; statement to the connection
// initialization as well as registering the driver with the sql package as "sqlite3" for compatibility // initialization as well as registering the driver with the sql package as "sqlite3" for compatibility
// with entgo.io // with entgo.io
@ -35,6 +35,6 @@ func (d CGOFreeSqliteDriver) Open(name string) (conn driver.Conn, err error) {
return conn, err return conn, err
} }
func init() { func init() { //nolint:gochecknoinits
sql.Register("sqlite3", CGOFreeSqliteDriver{Driver: &sqlite.Driver{}}) sql.Register("sqlite3", CGOFreeSqliteDriver{Driver: &sqlite.Driver{}})
} }

View file

@ -1,3 +1,4 @@
// Package faker provides a simple interface for generating fake data for testing.
package faker package faker
import ( import (

View file

@ -0,0 +1,2 @@
// Package hasher provides a simple interface for hashing and verifying passwords.
package hasher

View file

@ -9,7 +9,7 @@ import (
var enabled = true var enabled = true
func init() { func init() { // nolint: gochecknoinits
disableHas := os.Getenv("UNSAFE_DISABLE_PASSWORD_PROJECTION") == "yes_i_am_sure" disableHas := os.Getenv("UNSAFE_DISABLE_PASSWORD_PROJECTION") == "yes_i_am_sure"
if disableHas { if disableHas {

View file

@ -1,3 +1,4 @@
// Package mailer provides a simple mailer for sending emails.
package mailer package mailer
import ( import (

View file

@ -5,7 +5,7 @@ import (
"os" "os"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
const ( const (
@ -59,5 +59,5 @@ func Test_Mailer(t *testing.T) {
err = mailer.Send(msg) err = mailer.Send(msg)
assert.Nil(t, err) require.NoError(t, err)
} }

View file

@ -1,3 +1,4 @@
// Package pathlib provides a way to safely create a file path without overwriting any existing files.
package pathlib package pathlib
import ( import (
@ -14,7 +15,7 @@ var dirReader dirReaderFunc = func(directory string) []string {
if err != nil { if err != nil {
return nil return nil
} }
defer f.Close() defer func() { _ = f.Close() }()
names, err := f.Readdirnames(-1) names, err := f.Readdirnames(-1)
if err != nil { if err != nil {

View file

@ -1,3 +1,4 @@
// Package set provides a simple set implementation.
package set package set
type key interface { type key interface {