implement label updates for items

This commit is contained in:
Hayden 2022-09-11 17:50:13 -08:00
parent 50e6d353dc
commit 15a610fa8b
5 changed files with 155 additions and 11 deletions

View file

@ -3,7 +3,10 @@ version: "3"
tasks:
generate:
cmds:
- cd backend && go generate ./...
- |
cd backend && ent generate ./ent/schema \
--template=ent/schema/templates/stringer.tmpl \
--template=ent/schema/templates/has_id.tmpl
- cd backend/app/api/ && swag fmt
- cd backend/app/api/ && swag init --dir=./,../../internal,../../pkgs
- |

View file

@ -0,0 +1,62 @@
package repo
import "github.com/google/uuid"
// HasID is an interface to entities that have an ID uuid.UUID field and a GetID() method.
// This interface is fulfilled by all entities generated by entgo.io/ent via a custom template
type HasID interface {
GetID() uuid.UUID
}
// IDSet is a utility set-like type for working with sets of uuid.UUIDs within a repository
// instance. Most useful for comparing lists of UUIDs for processing relationship
// IDs and remove/adding relationships as required.
//
// # See how ItemRepo uses it to manage the Labels-To-Items relationship
//
// NOTE: may be worth moving this to a more generic package/set implementation
// or use a 3rd party set library, but this is good enough for now
type IDSet struct {
mp map[uuid.UUID]struct{}
}
func NewIDSet(l int) *IDSet {
return &IDSet{
mp: make(map[uuid.UUID]struct{}, l),
}
}
func EntitiesToIDSet[T HasID](entities []T) *IDSet {
s := NewIDSet(len(entities))
for _, e := range entities {
s.Add(e.GetID())
}
return s
}
func (t *IDSet) Slice() []uuid.UUID {
s := make([]uuid.UUID, 0, len(t.mp))
for k := range t.mp {
s = append(s, k)
}
return s
}
func (t *IDSet) Add(ids ...uuid.UUID) {
for _, id := range ids {
t.mp[id] = struct{}{}
}
}
func (t *IDSet) Has(id uuid.UUID) bool {
_, ok := t.mp[id]
return ok
}
func (t *IDSet) Len() int {
return len(t.mp)
}
func (t *IDSet) Remove(id uuid.UUID) {
delete(t.mp, id)
}

View file

@ -27,6 +27,7 @@ func (e *ItemsRepository) GetOne(ctx context.Context, id uuid.UUID) (*ent.Item,
Only(ctx)
}
// GetAll returns all the items in the database with the Labels and Locations eager loaded.
func (e *ItemsRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]*ent.Item, error) {
return e.db.Item.Query().
Where(item.HasGroupWith(group.ID(gid))).
@ -78,8 +79,26 @@ func (e *ItemsRepository) Update(ctx context.Context, data types.ItemUpdate) (*e
SetWarrantyExpires(data.WarrantyExpires).
SetWarrantyDetails(data.WarrantyDetails)
err := q.Exec(ctx)
currentLabels, err := e.db.Item.Query().Where(item.ID(data.ID)).QueryLabel().All(ctx)
if err != nil {
return nil, err
}
set := EntitiesToIDSet(currentLabels)
for _, l := range data.LabelIDs {
if set.Has(l) {
set.Remove(l)
continue
}
q.AddLabelIDs(l)
}
if set.Len() > 0 {
q.RemoveLabelIDs(set.Slice()...)
}
err = q.Exec(ctx)
if err != nil {
return nil, err
}

View file

@ -5,6 +5,7 @@ import (
"testing"
"time"
"github.com/google/uuid"
"github.com/hay-kot/content/backend/ent"
"github.com/hay-kot/content/backend/internal/types"
"github.com/stretchr/testify/assert"
@ -130,6 +131,66 @@ func TestItemsRepository_Delete(t *testing.T) {
assert.Empty(t, results)
}
func TestItemsRepository_Update_Labels(t *testing.T) {
entity := useItems(t, 1)[0]
labels := useLabels(t, 3)
labelsIDs := []uuid.UUID{labels[0].ID, labels[1].ID, labels[2].ID}
type args struct {
labelIds []uuid.UUID
}
tests := []struct {
name string
args args
want []uuid.UUID
}{
{
name: "add all labels",
args: args{
labelIds: labelsIDs,
},
want: labelsIDs,
},
{
name: "update with one label",
args: args{
labelIds: labelsIDs[:1],
},
want: labelsIDs[:1],
},
{
name: "add one new label to existing single label",
args: args{
labelIds: labelsIDs[1:],
},
want: labelsIDs[1:],
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Apply all labels to entity
updateData := types.ItemUpdate{
ID: entity.ID,
Name: entity.Name,
LocationID: entity.Edges.Location.ID,
LabelIDs: tt.args.labelIds,
}
updated, err := tRepos.Items.Update(context.Background(), updateData)
assert.NoError(t, err)
assert.Len(t, tt.want, len(updated.Edges.Label))
for _, label := range updated.Edges.Label {
assert.Contains(t, tt.want, label.ID)
}
})
}
}
func TestItemsRepository_Update(t *testing.T) {
entities := useItems(t, 3)

View file

@ -16,7 +16,7 @@ func labelFactory() types.LabelCreate {
}
}
func useLabels(t *testing.T, len int) ([]*ent.Label, func()) {
func useLabels(t *testing.T, len int) []*ent.Label {
t.Helper()
labels := make([]*ent.Label, len)
@ -28,17 +28,17 @@ func useLabels(t *testing.T, len int) ([]*ent.Label, func()) {
labels[i] = item
}
return labels, func() {
t.Cleanup(func() {
for _, item := range labels {
err := tRepos.Labels.Delete(context.Background(), item.ID)
assert.NoError(t, err)
_ = tRepos.Labels.Delete(context.Background(), item.ID)
}
}
})
return labels
}
func TestLabelRepository_Get(t *testing.T) {
labels, cleanup := useLabels(t, 1)
defer cleanup()
labels := useLabels(t, 1)
label := labels[0]
// Get by ID
@ -48,8 +48,7 @@ func TestLabelRepository_Get(t *testing.T) {
}
func TestLabelRepositoryGetAll(t *testing.T) {
_, cleanup := useLabels(t, 10)
defer cleanup()
useLabels(t, 10)
all, err := tRepos.Labels.GetAll(context.Background(), tGroup.ID)
assert.NoError(t, err)