diff --git a/Taskfile.yml b/Taskfile.yml index 71d38c7..ad23351 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -3,7 +3,10 @@ version: "3" tasks: generate: cmds: - - cd backend && go generate ./... + - | + cd backend && ent generate ./ent/schema \ + --template=ent/schema/templates/stringer.tmpl \ + --template=ent/schema/templates/has_id.tmpl - cd backend/app/api/ && swag fmt - cd backend/app/api/ && swag init --dir=./,../../internal,../../pkgs - | diff --git a/backend/internal/repo/id_set.go b/backend/internal/repo/id_set.go new file mode 100644 index 0000000..0041d93 --- /dev/null +++ b/backend/internal/repo/id_set.go @@ -0,0 +1,62 @@ +package repo + +import "github.com/google/uuid" + +// HasID is an interface to entities that have an ID uuid.UUID field and a GetID() method. +// This interface is fulfilled by all entities generated by entgo.io/ent via a custom template +type HasID interface { + GetID() uuid.UUID +} + +// IDSet is a utility set-like type for working with sets of uuid.UUIDs within a repository +// instance. Most useful for comparing lists of UUIDs for processing relationship +// IDs and remove/adding relationships as required. +// +// # See how ItemRepo uses it to manage the Labels-To-Items relationship +// +// NOTE: may be worth moving this to a more generic package/set implementation +// or use a 3rd party set library, but this is good enough for now +type IDSet struct { + mp map[uuid.UUID]struct{} +} + +func NewIDSet(l int) *IDSet { + return &IDSet{ + mp: make(map[uuid.UUID]struct{}, l), + } +} + +func EntitiesToIDSet[T HasID](entities []T) *IDSet { + s := NewIDSet(len(entities)) + for _, e := range entities { + s.Add(e.GetID()) + } + return s +} + +func (t *IDSet) Slice() []uuid.UUID { + s := make([]uuid.UUID, 0, len(t.mp)) + for k := range t.mp { + s = append(s, k) + } + return s +} + +func (t *IDSet) Add(ids ...uuid.UUID) { + for _, id := range ids { + t.mp[id] = struct{}{} + } +} + +func (t *IDSet) Has(id uuid.UUID) bool { + _, ok := t.mp[id] + return ok +} + +func (t *IDSet) Len() int { + return len(t.mp) +} + +func (t *IDSet) Remove(id uuid.UUID) { + delete(t.mp, id) +} diff --git a/backend/internal/repo/repo_items.go b/backend/internal/repo/repo_items.go index a1c90df..0d2e0e5 100644 --- a/backend/internal/repo/repo_items.go +++ b/backend/internal/repo/repo_items.go @@ -27,6 +27,7 @@ func (e *ItemsRepository) GetOne(ctx context.Context, id uuid.UUID) (*ent.Item, Only(ctx) } +// GetAll returns all the items in the database with the Labels and Locations eager loaded. func (e *ItemsRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]*ent.Item, error) { return e.db.Item.Query(). Where(item.HasGroupWith(group.ID(gid))). @@ -78,8 +79,26 @@ func (e *ItemsRepository) Update(ctx context.Context, data types.ItemUpdate) (*e SetWarrantyExpires(data.WarrantyExpires). SetWarrantyDetails(data.WarrantyDetails) - err := q.Exec(ctx) + currentLabels, err := e.db.Item.Query().Where(item.ID(data.ID)).QueryLabel().All(ctx) + if err != nil { + return nil, err + } + set := EntitiesToIDSet(currentLabels) + + for _, l := range data.LabelIDs { + if set.Has(l) { + set.Remove(l) + continue + } + q.AddLabelIDs(l) + } + + if set.Len() > 0 { + q.RemoveLabelIDs(set.Slice()...) + } + + err = q.Exec(ctx) if err != nil { return nil, err } diff --git a/backend/internal/repo/repo_items_test.go b/backend/internal/repo/repo_items_test.go index 625ef6c..893bcb3 100644 --- a/backend/internal/repo/repo_items_test.go +++ b/backend/internal/repo/repo_items_test.go @@ -5,6 +5,7 @@ import ( "testing" "time" + "github.com/google/uuid" "github.com/hay-kot/content/backend/ent" "github.com/hay-kot/content/backend/internal/types" "github.com/stretchr/testify/assert" @@ -130,6 +131,66 @@ func TestItemsRepository_Delete(t *testing.T) { assert.Empty(t, results) } +func TestItemsRepository_Update_Labels(t *testing.T) { + entity := useItems(t, 1)[0] + labels := useLabels(t, 3) + + labelsIDs := []uuid.UUID{labels[0].ID, labels[1].ID, labels[2].ID} + + type args struct { + labelIds []uuid.UUID + } + + tests := []struct { + name string + args args + want []uuid.UUID + }{ + { + name: "add all labels", + args: args{ + labelIds: labelsIDs, + }, + want: labelsIDs, + }, + { + name: "update with one label", + args: args{ + labelIds: labelsIDs[:1], + }, + want: labelsIDs[:1], + }, + { + name: "add one new label to existing single label", + args: args{ + labelIds: labelsIDs[1:], + }, + want: labelsIDs[1:], + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Apply all labels to entity + updateData := types.ItemUpdate{ + ID: entity.ID, + Name: entity.Name, + LocationID: entity.Edges.Location.ID, + LabelIDs: tt.args.labelIds, + } + + updated, err := tRepos.Items.Update(context.Background(), updateData) + assert.NoError(t, err) + assert.Len(t, tt.want, len(updated.Edges.Label)) + + for _, label := range updated.Edges.Label { + assert.Contains(t, tt.want, label.ID) + } + }) + } + +} + func TestItemsRepository_Update(t *testing.T) { entities := useItems(t, 3) diff --git a/backend/internal/repo/repo_labels_test.go b/backend/internal/repo/repo_labels_test.go index 81a6f68..137376a 100644 --- a/backend/internal/repo/repo_labels_test.go +++ b/backend/internal/repo/repo_labels_test.go @@ -16,7 +16,7 @@ func labelFactory() types.LabelCreate { } } -func useLabels(t *testing.T, len int) ([]*ent.Label, func()) { +func useLabels(t *testing.T, len int) []*ent.Label { t.Helper() labels := make([]*ent.Label, len) @@ -28,17 +28,17 @@ func useLabels(t *testing.T, len int) ([]*ent.Label, func()) { labels[i] = item } - return labels, func() { + t.Cleanup(func() { for _, item := range labels { - err := tRepos.Labels.Delete(context.Background(), item.ID) - assert.NoError(t, err) + _ = tRepos.Labels.Delete(context.Background(), item.ID) } - } + }) + + return labels } func TestLabelRepository_Get(t *testing.T) { - labels, cleanup := useLabels(t, 1) - defer cleanup() + labels := useLabels(t, 1) label := labels[0] // Get by ID @@ -48,8 +48,7 @@ func TestLabelRepository_Get(t *testing.T) { } func TestLabelRepositoryGetAll(t *testing.T) { - _, cleanup := useLabels(t, 10) - defer cleanup() + useLabels(t, 10) all, err := tRepos.Labels.GetAll(context.Background(), tGroup.ID) assert.NoError(t, err)