From d01c85e75b889dd2088ea8bb9adf2e08b2b6bfbd Mon Sep 17 00:00:00 2001 From: Hayden <64056131+hay-kot@users.noreply.github.com> Date: Tue, 27 Sep 2022 15:44:42 -0800 Subject: [PATCH] refactor sets to own package --- backend/internal/repo/id_set.go | 57 +----- backend/internal/repo/repo_items.go | 4 +- backend/pkgs/set/funcs.go | 101 ++++++++++ backend/pkgs/set/funcs_test.go | 287 ++++++++++++++++++++++++++++ backend/pkgs/set/set.go | 56 ++++++ backend/pkgs/set/set_test.go | 255 ++++++++++++++++++++++++ 6 files changed, 709 insertions(+), 51 deletions(-) create mode 100644 backend/pkgs/set/funcs.go create mode 100644 backend/pkgs/set/funcs_test.go create mode 100644 backend/pkgs/set/set.go create mode 100644 backend/pkgs/set/set_test.go diff --git a/backend/internal/repo/id_set.go b/backend/internal/repo/id_set.go index dcb92ea..a39ac40 100644 --- a/backend/internal/repo/id_set.go +++ b/backend/internal/repo/id_set.go @@ -1,6 +1,9 @@ package repo -import "github.com/google/uuid" +import ( + "github.com/google/uuid" + "github.com/hay-kot/homebox/backend/pkgs/set" +) // HasID is an interface to entities that have an ID uuid.UUID field and a GetID() method. // This interface is fulfilled by all entities generated by entgo.io/ent via a custom template @@ -8,55 +11,11 @@ type HasID interface { GetID() uuid.UUID } -// IDSet is a utility set-like type for working with sets of uuid.UUIDs within a repository -// instance. Most useful for comparing lists of UUIDs for processing relationship -// IDs and remove/adding relationships as required. -// -// # See how ItemRepo uses it to manage the Labels-To-Items relationship -// -// NOTE: may be worth moving this to a more generic package/set implementation -// or use a 3rd party set library, but this is good enough for now -type IDSet struct { - mp map[uuid.UUID]struct{} -} - -func NewIDSet(l int) *IDSet { - return &IDSet{ - mp: make(map[uuid.UUID]struct{}, l), - } -} - -func entToIDSet[T HasID](entities []T) *IDSet { - s := NewIDSet(len(entities)) +func newIDSet[T HasID](entities []T) set.Set[uuid.UUID] { + uuids := make([]uuid.UUID, 0, len(entities)) for _, e := range entities { - s.Add(e.GetID()) + uuids = append(uuids, e.GetID()) } - return s -} -func (t *IDSet) Slice() []uuid.UUID { - s := make([]uuid.UUID, 0, len(t.mp)) - for k := range t.mp { - s = append(s, k) - } - return s -} - -func (t *IDSet) Add(ids ...uuid.UUID) { - for _, id := range ids { - t.mp[id] = struct{}{} - } -} - -func (t *IDSet) Has(id uuid.UUID) bool { - _, ok := t.mp[id] - return ok -} - -func (t *IDSet) Len() int { - return len(t.mp) -} - -func (t *IDSet) Remove(id uuid.UUID) { - delete(t.mp, id) + return set.New(uuids...) } diff --git a/backend/internal/repo/repo_items.go b/backend/internal/repo/repo_items.go index c916596..e6375f7 100644 --- a/backend/internal/repo/repo_items.go +++ b/backend/internal/repo/repo_items.go @@ -275,10 +275,10 @@ func (e *ItemsRepository) UpdateByGroup(ctx context.Context, gid uuid.UUID, data return ItemOut{}, err } - set := entToIDSet(currentLabels) + set := newIDSet(currentLabels) for _, l := range data.LabelIDs { - if set.Has(l) { + if set.Contains(l) { set.Remove(l) continue } diff --git a/backend/pkgs/set/funcs.go b/backend/pkgs/set/funcs.go new file mode 100644 index 0000000..0d9a261 --- /dev/null +++ b/backend/pkgs/set/funcs.go @@ -0,0 +1,101 @@ +package set + +// Diff returns the difference between two sets +func Diff[T key](a, b Set[T]) Set[T] { + s := New[T]() + for k := range a.mp { + if !b.Contains(k) { + s.Insert(k) + } + } + return s +} + +// Intersect returns the intersection between two sets +func Intersect[T key](a, b Set[T]) Set[T] { + s := New[T]() + for k := range a.mp { + if b.Contains(k) { + s.Insert(k) + } + } + return s +} + +// Union returns the union between two sets +func Union[T key](a, b Set[T]) Set[T] { + s := New[T]() + for k := range a.mp { + s.Insert(k) + } + for k := range b.mp { + s.Insert(k) + } + return s +} + +// Xor returns the symmetric difference between two sets +func Xor[T key](a, b Set[T]) Set[T] { + s := New[T]() + for k := range a.mp { + if !b.Contains(k) { + s.Insert(k) + } + } + for k := range b.mp { + if !a.Contains(k) { + s.Insert(k) + } + } + return s +} + +// Equal returns true if two sets are equal +func Equal[T key](a, b Set[T]) bool { + if a.Len() != b.Len() { + return false + } + for k := range a.mp { + if !b.Contains(k) { + return false + } + } + return true +} + +// Subset returns true if a is a subset of b +func Subset[T key](a, b Set[T]) bool { + if a.Len() > b.Len() { + return false + } + for k := range a.mp { + if !b.Contains(k) { + return false + } + } + return true +} + +// Superset returns true if a is a superset of b +func Superset[T key](a, b Set[T]) bool { + if a.Len() < b.Len() { + return false + } + for k := range b.mp { + if !a.Contains(k) { + return false + } + } + return true +} + +// Disjoint returns true if two sets are disjoint +func Disjoint[T key](a, b Set[T]) bool { + for k := range a.mp { + if b.Contains(k) { + return false + } + } + return true + +} diff --git a/backend/pkgs/set/funcs_test.go b/backend/pkgs/set/funcs_test.go new file mode 100644 index 0000000..ab3aa0e --- /dev/null +++ b/backend/pkgs/set/funcs_test.go @@ -0,0 +1,287 @@ +package set + +import ( + "reflect" + "testing" +) + +type args struct { + a Set[string] + b Set[string] +} + +var ( + argsBasic = args{ + a: New("a", "b", "c"), + b: New("b", "c", "d"), + } + + argsNoOverlap = args{ + a: New("a", "b", "c"), + b: New("d", "e", "f"), + } + + argsIdentical = args{ + a: New("a", "b", "c"), + b: New("a", "b", "c"), + } +) + +func TestDiff(t *testing.T) { + + tests := []struct { + name string + args args + want Set[string] + }{ + { + name: "diff basic", + args: argsBasic, + want: New("a"), + }, + { + name: "diff empty", + args: argsIdentical, + want: New[string](), + }, + { + name: "diff no overlap", + args: argsNoOverlap, + want: New("a", "b", "c"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Diff(tt.args.a, tt.args.b); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Diff() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestIntersect(t *testing.T) { + tests := []struct { + name string + args args + want Set[string] + }{ + { + name: "intersect basic", + args: argsBasic, + want: New("b", "c"), + }, + { + name: "identical sets", + args: argsIdentical, + want: New("a", "b", "c"), + }, + { + name: "no overlap", + args: argsNoOverlap, + want: New[string](), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Intersect(tt.args.a, tt.args.b); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Intersect() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnion(t *testing.T) { + tests := []struct { + name string + args args + want Set[string] + }{ + { + name: "intersect basic", + args: argsBasic, + want: New("a", "b", "c", "d"), + }, + { + name: "identical sets", + args: argsIdentical, + want: New("a", "b", "c"), + }, + { + name: "no overlap", + args: argsNoOverlap, + want: New("a", "b", "c", "d", "e", "f"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Union(tt.args.a, tt.args.b); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Union() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestXor(t *testing.T) { + tests := []struct { + name string + args args + want Set[string] + }{ + { + name: "xor basic", + args: argsBasic, + want: New("a", "d"), + }, + { + name: "identical sets", + args: argsIdentical, + want: New[string](), + }, + { + name: "no overlap", + args: argsNoOverlap, + want: New("a", "b", "c", "d", "e", "f"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Xor(tt.args.a, tt.args.b); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Xor() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEqual(t *testing.T) { + tests := []struct { + name string + args args + want bool + }{ + { + name: "equal basic", + args: argsBasic, + want: false, + }, + { + name: "identical sets", + args: argsIdentical, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Equal(tt.args.a, tt.args.b); got != tt.want { + t.Errorf("Equal() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSubset(t *testing.T) { + type args struct { + a Set[string] + b Set[string] + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "subset basic", + args: args{ + a: New("a", "b"), + b: New("a", "b", "c"), + }, + want: true, + }, + { + name: "subset basic false", + args: args{ + a: New("a", "b", "d"), + b: New("a", "b", "c"), + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Subset(tt.args.a, tt.args.b); got != tt.want { + t.Errorf("Subset() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSuperset(t *testing.T) { + type args struct { + a Set[string] + b Set[string] + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "superset basic", + args: args{ + a: New("a", "b", "c"), + b: New("a", "b"), + }, + want: true, + }, + { + name: "superset basic false", + args: args{ + a: New("a", "b", "c"), + b: New("a", "b", "d"), + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Superset(tt.args.a, tt.args.b); got != tt.want { + t.Errorf("Superset() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestDisjoint(t *testing.T) { + type args struct { + a Set[string] + b Set[string] + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "disjoint basic", + args: args{ + a: New("a", "b"), + b: New("c", "d"), + }, + want: true, + }, + { + name: "disjoint basic false", + args: args{ + a: New("a", "b"), + b: New("b", "c"), + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Disjoint(tt.args.a, tt.args.b); got != tt.want { + t.Errorf("Disjoint() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/backend/pkgs/set/set.go b/backend/pkgs/set/set.go new file mode 100644 index 0000000..f2ffecc --- /dev/null +++ b/backend/pkgs/set/set.go @@ -0,0 +1,56 @@ +package set + +type key interface { + comparable +} + +type Set[T key] struct { + mp map[T]struct{} +} + +func New[T key](v ...T) Set[T] { + mp := make(map[T]struct{}, len(v)) + + s := Set[T]{mp} + + s.Insert(v...) + return s +} + +func (s Set[T]) Insert(v ...T) { + for _, e := range v { + s.mp[e] = struct{}{} + } +} + +func (s Set[T]) Remove(v ...T) { + for _, e := range v { + delete(s.mp, e) + } +} + +func (s Set[T]) Contains(v T) bool { + _, ok := s.mp[v] + return ok +} + +func (s Set[T]) ContainsAll(v ...T) bool { + for _, e := range v { + if !s.Contains(e) { + return false + } + } + return true +} + +func (s Set[T]) Slice() []T { + slice := make([]T, 0, len(s.mp)) + for k := range s.mp { + slice = append(slice, k) + } + return slice +} + +func (s Set[T]) Len() int { + return len(s.mp) +} diff --git a/backend/pkgs/set/set_test.go b/backend/pkgs/set/set_test.go new file mode 100644 index 0000000..b571298 --- /dev/null +++ b/backend/pkgs/set/set_test.go @@ -0,0 +1,255 @@ +package set + +import ( + "reflect" + "testing" +) + +func TestNew(t *testing.T) { + type args struct { + v []string + } + tests := []struct { + name string + args args + want Set[string] + }{ + { + name: "new", + args: args{ + v: []string{"a", "b", "c"}, + }, + want: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + }, + { + name: "new empty", + args: args{ + v: []string{}, + }, + want: Set[string]{ + mp: map[string]struct{}{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := New(tt.args.v...); !reflect.DeepEqual(got, tt.want) { + t.Errorf("New() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSet_Insert(t *testing.T) { + type args struct { + v []string + } + tests := []struct { + name string + s Set[string] + args args + want Set[string] + }{ + { + name: "insert", + s: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + args: args{ + v: []string{"d", "e", "f"}, + }, + want: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + "d": {}, + "e": {}, + "f": {}, + }, + }, + }, + { + name: "insert empty", + s: Set[string]{ + mp: map[string]struct{}{}, + }, + args: args{ + v: []string{"a", "b", "c"}, + }, + want: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s.Insert(tt.args.v...) + if !reflect.DeepEqual(tt.s, tt.want) { + t.Errorf("Set.Insert() = %v, want %v", tt.s, tt.want) + } + }) + } +} + +func TestSet_Delete(t *testing.T) { + type args struct { + v []string + } + tests := []struct { + name string + s Set[string] + args args + want Set[string] + }{ + { + name: "insert", + s: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + "d": {}, + "e": {}, + "f": {}, + }, + }, + args: args{ + v: []string{"d", "e", "f"}, + }, + want: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + }, + { + name: "delete empty", + s: Set[string]{ + mp: map[string]struct{}{}, + }, + args: args{ + v: []string{}, + }, + want: Set[string]{ + mp: map[string]struct{}{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.s.Remove(tt.args.v...) + if !reflect.DeepEqual(tt.s, tt.want) { + t.Errorf("Set.Delete() = %v, want %v", tt.s, tt.want) + } + }) + } +} + +func TestSet_ContainsAll(t *testing.T) { + type args struct { + v []string + } + tests := []struct { + name string + s Set[string] + args args + want bool + }{ + { + name: "contains", + s: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + args: args{ + v: []string{"a", "b", "c"}, + }, + want: true, + }, + { + name: "contains empty", + s: Set[string]{ + mp: map[string]struct{}{}, + }, + args: args{ + v: []string{}, + }, + want: true, + }, + { + name: "not contains", + s: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + args: args{ + v: []string{"d", "e", "f"}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.ContainsAll(tt.args.v...); got != tt.want { + t.Errorf("Set.ContainsAll() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSet_Slice(t *testing.T) { + tests := []struct { + name string + s Set[string] + want []string + }{ + { + name: "slice", + s: Set[string]{ + mp: map[string]struct{}{ + "a": {}, + "b": {}, + "c": {}, + }, + }, + want: []string{"a", "b", "c"}, + }, + { + name: "slice empty", + s: Set[string]{ + mp: map[string]struct{}{}, + }, + want: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.s.Slice(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Set.Slice() = %v, want %v", got, tt.want) + } + }) + } +}