fix: conditionally filter parent locations (#133)

This commit is contained in:
Hayden 2022-11-02 11:54:43 -08:00 committed by GitHub
parent fbcbde836a
commit 8e1947d971
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 135 additions and 67 deletions

View file

@ -23,7 +23,7 @@ type ItemService struct {
at attachmentTokens
}
func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]string) (int, error) {
func (svc *ItemService) CsvImport(ctx context.Context, GID uuid.UUID, data [][]string) (int, error) {
loaded := []csvRow{}
// Skip first row
@ -66,7 +66,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
// Bootstrap the locations and labels so we can reuse the created IDs for the items
locations := map[string]uuid.UUID{}
existingLocation, err := svc.repo.Locations.GetAll(ctx, gid)
existingLocation, err := svc.repo.Locations.GetAll(ctx, GID, repo.LocationQuery{})
if err != nil {
return 0, err
}
@ -75,7 +75,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
}
labels := map[string]uuid.UUID{}
existingLabels, err := svc.repo.Labels.GetAll(ctx, gid)
existingLabels, err := svc.repo.Labels.GetAll(ctx, GID)
if err != nil {
return 0, err
}
@ -87,7 +87,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
// Locations
if _, exists := locations[row.Location]; !exists {
result, err := svc.repo.Locations.Create(ctx, gid, repo.LocationCreate{
result, err := svc.repo.Locations.Create(ctx, GID, repo.LocationCreate{
Name: row.Location,
Description: "",
})
@ -103,7 +103,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
if _, exists := labels[label]; exists {
continue
}
result, err := svc.repo.Labels.Create(ctx, gid, repo.LabelCreate{
result, err := svc.repo.Labels.Create(ctx, GID, repo.LabelCreate{
Name: label,
Description: "",
})
@ -119,7 +119,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
for _, row := range loaded {
// Check Import Ref
if row.Item.ImportRef != "" {
exists, err := svc.repo.Items.CheckRef(ctx, gid, row.Item.ImportRef)
exists, err := svc.repo.Items.CheckRef(ctx, GID, row.Item.ImportRef)
if exists {
continue
}
@ -139,7 +139,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
Str("location", row.Location).
Msgf("Creating Item: %s", row.Item.Name)
result, err := svc.repo.Items.Create(ctx, gid, repo.ItemCreate{
result, err := svc.repo.Items.Create(ctx, GID, repo.ItemCreate{
ImportRef: row.Item.ImportRef,
Name: row.Item.Name,
Description: row.Item.Description,
@ -152,7 +152,7 @@ func (svc *ItemService) CsvImport(ctx context.Context, gid uuid.UUID, data [][]s
}
// Update the item with the rest of the data
_, err = svc.repo.Items.UpdateByGroup(ctx, gid, repo.ItemUpdate{
_, err = svc.repo.Items.UpdateByGroup(ctx, GID, repo.ItemUpdate{
// Edges
LocationID: locationID,
LabelIDs: labelIDs,

View file

@ -5,6 +5,7 @@ import (
"testing"
"github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/data/repo"
"github.com/stretchr/testify/assert"
)
@ -38,7 +39,7 @@ func TestItemService_CsvImport(t *testing.T) {
dataCsv = append(dataCsv, newCsvRow(item))
}
allLocation, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID)
allLocation, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID, repo.LocationQuery{})
assert.NoError(t, err)
locNames := []string{}
for _, loc := range allLocation {

View file

@ -2,6 +2,7 @@ package repo
import (
"context"
"strings"
"time"
"github.com/google/uuid"
@ -90,8 +91,12 @@ func mapLocationOut(location *ent.Location) LocationOut {
}
}
type LocationQuery struct {
FilterChildren bool `json:"filterChildren"`
}
// GetALlWithCount returns all locations with item count field populated
func (r *LocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]LocationOutCount, error) {
func (r *LocationRepository) GetAll(ctx context.Context, GID uuid.UUID, filter LocationQuery) ([]LocationOutCount, error) {
query := `--sql
SELECT
id,
@ -111,13 +116,18 @@ func (r *LocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]L
FROM
locations
WHERE
locations.group_locations = ?
AND locations.location_children IS NULL
locations.group_locations = ? {{ FILTER_CHILDREN }}
ORDER BY
locations.name ASC
`
rows, err := r.db.Sql().QueryContext(ctx, query, groupId)
if filter.FilterChildren {
query = strings.Replace(query, "{{ FILTER_CHILDREN }}", "AND locations.location_children IS NULL", 1)
} else {
query = strings.Replace(query, "{{ FILTER_CHILDREN }}", "", 1)
}
rows, err := r.db.Sql().QueryContext(ctx, query, GID)
if err != nil {
return nil, err
}

View file

@ -43,7 +43,7 @@ func TestLocationRepositoryGetAllWithCount(t *testing.T) {
assert.NoError(t, err)
results, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID)
results, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID, LocationQuery{})
assert.NoError(t, err)
for _, loc := range results {