diff --git a/digest/set.go b/digest/set.go index 271d35db..3fac41b4 100644 --- a/digest/set.go +++ b/digest/set.go @@ -4,6 +4,7 @@ import ( "errors" "sort" "strings" + "sync" ) var ( @@ -27,6 +28,7 @@ var ( // the complete set of digests. To mitigate collisions, an // appropriately long short code should be used. type Set struct { + mutex sync.RWMutex entries digestEntries } @@ -63,6 +65,8 @@ func checkShortMatch(alg Algorithm, hex, shortAlg, shortHex string) bool { // with an empty digest value. If multiple matches are found // ErrDigestAmbiguous will be returned with an empty digest value. func (dst *Set) Lookup(d string) (Digest, error) { + dst.mutex.RLock() + defer dst.mutex.RUnlock() if len(dst.entries) == 0 { return "", ErrDigestNotFound } @@ -101,13 +105,15 @@ func (dst *Set) Lookup(d string) (Digest, error) { return dst.entries[idx].digest, nil } -// Add adds the given digests to the set. An error will be returned +// Add adds the given digest to the set. An error will be returned // if the given digest is invalid. If the digest already exists in the -// table, this operation will be a no-op. +// set, this operation will be a no-op. func (dst *Set) Add(d Digest) error { if err := d.Validate(); err != nil { return err } + dst.mutex.Lock() + defer dst.mutex.Unlock() entry := &digestEntry{alg: d.Algorithm(), val: d.Hex(), digest: d} searchFunc := func(i int) bool { if dst.entries[i].val == entry.val { @@ -130,12 +136,56 @@ func (dst *Set) Add(d Digest) error { return nil } +// Remove removes the given digest from the set. An err will be +// returned if the given digest is invalid. If the digest does +// not exist in the set, this operation will be a no-op. +func (dst *Set) Remove(d Digest) error { + if err := d.Validate(); err != nil { + return err + } + dst.mutex.Lock() + defer dst.mutex.Unlock() + entry := &digestEntry{alg: d.Algorithm(), val: d.Hex(), digest: d} + searchFunc := func(i int) bool { + if dst.entries[i].val == entry.val { + return dst.entries[i].alg >= entry.alg + } + return dst.entries[i].val >= entry.val + } + idx := sort.Search(len(dst.entries), searchFunc) + // Not found if idx is after or value at idx is not digest + if idx == len(dst.entries) || dst.entries[idx].digest != d { + return nil + } + + entries := dst.entries + copy(entries[idx:], entries[idx+1:]) + entries = entries[:len(entries)-1] + dst.entries = entries + + return nil +} + +// All returns all the digests in the set +func (dst *Set) All() []Digest { + dst.mutex.RLock() + defer dst.mutex.RUnlock() + retValues := make([]Digest, len(dst.entries)) + for i := range dst.entries { + retValues[i] = dst.entries[i].digest + } + + return retValues +} + // ShortCodeTable returns a map of Digest to unique short codes. The // length represents the minimum value, the maximum length may be the // entire value of digest if uniqueness cannot be achieved without the // full value. This function will attempt to make short codes as short // as possible to be unique. func ShortCodeTable(dst *Set, length int) map[Digest]string { + dst.mutex.RLock() + defer dst.mutex.RUnlock() m := make(map[Digest]string, len(dst.entries)) l := length resetIdx := 0 diff --git a/digest/set_test.go b/digest/set_test.go index faeba6d3..0c0f650d 100644 --- a/digest/set_test.go +++ b/digest/set_test.go @@ -125,6 +125,66 @@ func TestAddDuplication(t *testing.T) { } } +func TestRemove(t *testing.T) { + digests, err := createDigests(10) + if err != nil { + t.Fatal(err) + } + + dset := NewSet() + for i := range digests { + if err := dset.Add(digests[i]); err != nil { + t.Fatal(err) + } + } + + dgst, err := dset.Lookup(digests[0].String()) + if err != nil { + t.Fatal(err) + } + if dgst != digests[0] { + t.Fatalf("Unexpected digest value:\n\tExpected: %s\n\tActual: %s", digests[0], dgst) + } + + if err := dset.Remove(digests[0]); err != nil { + t.Fatal(err) + } + + if _, err := dset.Lookup(digests[0].String()); err != ErrDigestNotFound { + t.Fatalf("Expected error %v when looking up removed digest, got %v", ErrDigestNotFound, err) + } +} + +func TestAll(t *testing.T) { + digests, err := createDigests(100) + if err != nil { + t.Fatal(err) + } + + dset := NewSet() + for i := range digests { + if err := dset.Add(digests[i]); err != nil { + t.Fatal(err) + } + } + + all := map[Digest]struct{}{} + for _, dgst := range dset.All() { + all[dgst] = struct{}{} + } + + if len(all) != len(digests) { + t.Fatalf("Unexpected number of unique digests found:\n\tExpected: %d\n\tActual: %d", len(digests), len(all)) + } + + for i, dgst := range digests { + if _, ok := all[dgst]; !ok { + t.Fatalf("Missing element at position %d: %s", i, dgst) + } + } + +} + func assertEqualShort(t *testing.T, actual, expected string) { if actual != expected { t.Fatalf("Unexpected short value:\n\tExpected: %s\n\tActual: %s", expected, actual) @@ -219,6 +279,29 @@ func benchLookupNTable(b *testing.B, n int, shortLen int) { } } +func benchRemoveNTable(b *testing.B, n int) { + digests, err := createDigests(n) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + dset := &Set{entries: digestEntries(make([]*digestEntry, 0, n))} + b.StopTimer() + for j := range digests { + if err = dset.Add(digests[j]); err != nil { + b.Fatal(err) + } + } + b.StartTimer() + for j := range digests { + if err = dset.Remove(digests[j]); err != nil { + b.Fatal(err) + } + } + } +} + func benchShortCodeNTable(b *testing.B, n int, shortLen int) { digests, err := createDigests(n) if err != nil { @@ -249,6 +332,18 @@ func BenchmarkAdd1000(b *testing.B) { benchAddNTable(b, 1000) } +func BenchmarkRemove10(b *testing.B) { + benchRemoveNTable(b, 10) +} + +func BenchmarkRemove100(b *testing.B) { + benchRemoveNTable(b, 100) +} + +func BenchmarkRemove1000(b *testing.B) { + benchRemoveNTable(b, 1000) +} + func BenchmarkLookup10(b *testing.B) { benchLookupNTable(b, 10, 12) }