From 508e2e59bd582d8dd7133a2293e97b7d4f985c7a Mon Sep 17 00:00:00 2001 From: Hayden <64056131+hay-kot@users.noreply.github.com> Date: Mon, 5 Sep 2022 00:26:21 -0800 Subject: [PATCH] tests: improve repo package coverage (#3) * refactor and add repo tests * add CI name * use atomic for test shutdown * use go 1.19 * add timeout --- .github/workflows/frontend.yaml | 6 +- .github/workflows/go.yaml | 3 +- .github/workflows/publish.yaml | 2 +- backend/go.mod | 2 +- backend/internal/repo/main_test.go | 34 ++-- backend/internal/repo/repo_group.go | 10 +- backend/internal/repo/repo_group_test.go | 4 +- backend/internal/repo/repo_items.go | 53 +++++- backend/internal/repo/repo_items_test.go | 178 +++++++++++++++++++ backend/internal/repo/repo_labels.go | 12 +- backend/internal/repo/repo_labels_test.go | 105 +++++++++++ backend/internal/repo/repo_locations.go | 14 +- backend/internal/repo/repo_locations_test.go | 45 +++-- backend/internal/repo/repo_tokens.go | 12 +- backend/internal/repo/repo_tokens_test.go | 79 +++++--- backend/internal/repo/repo_users.go | 42 ++--- backend/internal/repo/repo_users_test.go | 94 +++++----- backend/internal/repo/repos_all.go | 20 +-- backend/internal/types/users_types.go | 4 +- backend/pkgs/server/server_test.go | 7 +- 20 files changed, 540 insertions(+), 186 deletions(-) create mode 100644 backend/internal/repo/repo_items_test.go create mode 100644 backend/internal/repo/repo_labels_test.go diff --git a/.github/workflows/frontend.yaml b/.github/workflows/frontend.yaml index db769ae..e763e90 100644 --- a/.github/workflows/frontend.yaml +++ b/.github/workflows/frontend.yaml @@ -1,3 +1,5 @@ +name: Frontend / Integration + on: push: branches: [main] @@ -14,7 +16,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - uses: actions/setup-node@v3 with: @@ -27,7 +29,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - uses: actions/setup-node@v3 with: diff --git a/.github/workflows/go.yaml b/.github/workflows/go.yaml index 3e28501..e293079 100644 --- a/.github/workflows/go.yaml +++ b/.github/workflows/go.yaml @@ -15,7 +15,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - name: Install Task uses: arduino/setup-task@v1 @@ -28,6 +28,7 @@ jobs: # Optional: working directory, useful for monorepos working-directory: backend + args: --timeout=6m - name: Build API run: task api:build diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index f93e6ba..7677693 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -12,7 +12,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.18 + go-version: 1.19 - name: login to container registry run: docker login ghcr.io --username hay-kot --password $CR_PAT env: diff --git a/backend/go.mod b/backend/go.mod index 716d43d..0d9d4eb 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -1,6 +1,6 @@ module github.com/hay-kot/content/backend -go 1.18 +go 1.19 require ( entgo.io/ent v0.11.2 diff --git a/backend/internal/repo/main_test.go b/backend/internal/repo/main_test.go index eefc112..5644178 100644 --- a/backend/internal/repo/main_test.go +++ b/backend/internal/repo/main_test.go @@ -9,25 +9,34 @@ import ( "time" "github.com/hay-kot/content/backend/ent" + "github.com/hay-kot/content/backend/pkgs/faker" _ "github.com/mattn/go-sqlite3" ) var ( - testEntClient *ent.Client - testRepos *AllRepos - testUser *ent.User - testGroup *ent.Group + fk = faker.NewFaker() + + tClient *ent.Client + tRepos *AllRepos + tUser *ent.User + tGroup *ent.Group ) func bootstrap() { - ctx := context.Background() - testGroup, _ = testRepos.Groups.Create(ctx, "test-group") - testUser, _ = testRepos.Users.Create(ctx, UserFactory()) + var ( + err error + ctx = context.Background() + ) - if testGroup == nil || testUser == nil { - log.Fatal("Failed to bootstrap test data") + tGroup, err = tRepos.Groups.Create(ctx, "test-group") + if err != nil { + log.Fatal(err) } + tUser, err = tRepos.Users.Create(ctx, userFactory()) + if err != nil { + log.Fatal(err) + } } func TestMain(m *testing.M) { @@ -38,12 +47,13 @@ func TestMain(m *testing.M) { log.Fatalf("failed opening connection to sqlite: %v", err) } - if err := client.Schema.Create(context.Background()); err != nil { + err = client.Schema.Create(context.Background()) + if err != nil { log.Fatalf("failed creating schema resources: %v", err) } - testEntClient = client - testRepos = EntAllRepos(testEntClient) + tClient = client + tRepos = EntAllRepos(tClient) defer client.Close() bootstrap() diff --git a/backend/internal/repo/repo_group.go b/backend/internal/repo/repo_group.go index ee58508..c5c173e 100644 --- a/backend/internal/repo/repo_group.go +++ b/backend/internal/repo/repo_group.go @@ -7,14 +7,16 @@ import ( "github.com/hay-kot/content/backend/ent" ) -type EntGroupRepository struct { +type GroupRepository struct { db *ent.Client } -func (r *EntGroupRepository) Create(ctx context.Context, name string) (*ent.Group, error) { - return r.db.Group.Create().SetName(name).Save(ctx) +func (r *GroupRepository) Create(ctx context.Context, name string) (*ent.Group, error) { + return r.db.Group.Create(). + SetName(name). + Save(ctx) } -func (r *EntGroupRepository) GetOneId(ctx context.Context, id uuid.UUID) (*ent.Group, error) { +func (r *GroupRepository) GetOneId(ctx context.Context, id uuid.UUID) (*ent.Group, error) { return r.db.Group.Get(ctx, id) } diff --git a/backend/internal/repo/repo_group_test.go b/backend/internal/repo/repo_group_test.go index 69110d8..41bf686 100644 --- a/backend/internal/repo/repo_group_test.go +++ b/backend/internal/repo/repo_group_test.go @@ -8,13 +8,13 @@ import ( ) func Test_Group_Create(t *testing.T) { - g, err := testRepos.Groups.Create(context.Background(), "test") + g, err := tRepos.Groups.Create(context.Background(), "test") assert.NoError(t, err) assert.Equal(t, "test", g.Name) // Get by ID - foundGroup, err := testRepos.Groups.GetOneId(context.Background(), g.ID) + foundGroup, err := tRepos.Groups.GetOneId(context.Background(), g.ID) assert.NoError(t, err) assert.Equal(t, g.ID, foundGroup.ID) } diff --git a/backend/internal/repo/repo_items.go b/backend/internal/repo/repo_items.go index 877f5b9..09a2101 100644 --- a/backend/internal/repo/repo_items.go +++ b/backend/internal/repo/repo_items.go @@ -33,19 +33,54 @@ func (e *ItemsRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]*ent.Ite } func (e *ItemsRepository) Create(ctx context.Context, gid uuid.UUID, data types.ItemCreate) (*ent.Item, error) { - return e.db.Item.Create(). + q := e.db.Item.Create(). SetName(data.Name). SetDescription(data.Description). SetGroupID(gid). - AddLabelIDs(data.LabelIDs...). + SetLocationID(data.LocationID) + + if data.LabelIDs != nil && len(data.LabelIDs) > 0 { + q.AddLabelIDs(data.LabelIDs...) + } + + result, err := q.Save(ctx) + if err != nil { + return nil, err + } + + return e.GetOne(ctx, result.ID) +} + +func (e *ItemsRepository) Delete(ctx context.Context, id uuid.UUID) error { + return e.db.Item.DeleteOneID(id).Exec(ctx) +} + +func (e *ItemsRepository) Update(ctx context.Context, data types.ItemUpdate) (*ent.Item, error) { + q := e.db.Item.UpdateOneID(data.ID). + SetName(data.Name). + SetDescription(data.Description). SetLocationID(data.LocationID). - Save(ctx) -} + SetSerialNumber(data.SerialNumber). + SetModelNumber(data.ModelNumber). + SetManufacturer(data.Manufacturer). + SetPurchaseTime(data.PurchaseTime). + SetPurchaseFrom(data.PurchaseFrom). + SetPurchasePrice(data.PurchasePrice). + SetSoldTime(data.SoldTime). + SetSoldTo(data.SoldTo). + SetSoldPrice(data.SoldPrice). + SetSoldNotes(data.SoldNotes). + SetNotes(data.Notes) -func (e *ItemsRepository) Delete(ctx context.Context, gid uuid.UUID, id uuid.UUID) error { - panic("implement me") -} + if data.LabelIDs != nil && len(data.LabelIDs) > 0 { + q.AddLabelIDs(data.LabelIDs...) + } -func (e *ItemsRepository) Update(ctx context.Context, gid uuid.UUID, data types.ItemUpdate) (*ent.Item, error) { - panic("implement me") + err := q.Exec(ctx) + + if err != nil { + return nil, err + } + + return e.GetOne(ctx, data.ID) } diff --git a/backend/internal/repo/repo_items_test.go b/backend/internal/repo/repo_items_test.go new file mode 100644 index 0000000..66ef7d5 --- /dev/null +++ b/backend/internal/repo/repo_items_test.go @@ -0,0 +1,178 @@ +package repo + +import ( + "context" + "testing" + "time" + + "github.com/hay-kot/content/backend/ent" + "github.com/hay-kot/content/backend/internal/types" + "github.com/stretchr/testify/assert" +) + +func itemFactory() types.ItemCreate { + return types.ItemCreate{ + Name: fk.RandomString(10), + Description: fk.RandomString(100), + } +} + +func useItems(t *testing.T, len int) ([]*ent.Item, func()) { + t.Helper() + + location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) + assert.NoError(t, err) + + items := make([]*ent.Item, len) + for i := 0; i < len; i++ { + itm := itemFactory() + itm.LocationID = location.ID + + item, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm) + assert.NoError(t, err) + items[i] = item + } + + return items, func() { + for _, item := range items { + err := tRepos.Items.Delete(context.Background(), item.ID) + assert.NoError(t, err) + } + } +} + +func TestItemsRepository_GetOne(t *testing.T) { + entity, cleanup := useItems(t, 3) + defer cleanup() + + for _, item := range entity { + result, err := tRepos.Items.GetOne(context.Background(), item.ID) + assert.NoError(t, err) + assert.Equal(t, item.ID, result.ID) + } +} + +func TestItemsRepository_GetAll(t *testing.T) { + length := 10 + expected, cleanup := useItems(t, length) + defer cleanup() + + results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID) + assert.NoError(t, err) + + assert.Equal(t, length, len(results)) + + for _, item := range results { + for _, expectedItem := range expected { + if item.ID == expectedItem.ID { + assert.Equal(t, expectedItem.ID, item.ID) + assert.Equal(t, expectedItem.Name, item.Name) + assert.Equal(t, expectedItem.Description, item.Description) + } + } + } +} + +func TestItemsRepository_Create(t *testing.T) { + location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) + assert.NoError(t, err) + + itm := itemFactory() + itm.LocationID = location.ID + + result, err := tRepos.Items.Create(context.Background(), tGroup.ID, itm) + assert.NoError(t, err) + assert.NotEmpty(t, result.ID) + + // Cleanup + err = tRepos.Locations.Delete(context.Background(), location.ID) + assert.NoError(t, err) + + err = tRepos.Items.Delete(context.Background(), result.ID) + assert.NoError(t, err) +} + +func TestItemsRepository_Create_Location(t *testing.T) { + location, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) + assert.NoError(t, err) + assert.NotEmpty(t, location.ID) + + item := itemFactory() + item.LocationID = location.ID + + // Create Resource + result, err := tRepos.Items.Create(context.Background(), tGroup.ID, item) + assert.NoError(t, err) + assert.NotEmpty(t, result.ID) + + // Get Resource + foundItem, err := tRepos.Items.GetOne(context.Background(), result.ID) + assert.NoError(t, err) + assert.Equal(t, result.ID, foundItem.ID) + assert.Equal(t, location.ID, foundItem.Edges.Location.ID) + + // Cleanup + err = tRepos.Locations.Delete(context.Background(), location.ID) + assert.NoError(t, err) + err = tRepos.Items.Delete(context.Background(), result.ID) + assert.NoError(t, err) +} + +func TestItemsRepository_Delete(t *testing.T) { + entities, _ := useItems(t, 3) + + for _, item := range entities { + err := tRepos.Items.Delete(context.Background(), item.ID) + assert.NoError(t, err) + } + + results, err := tRepos.Items.GetAll(context.Background(), tGroup.ID) + assert.NoError(t, err) + assert.Empty(t, results) +} + +func TestItemsRepository_Update(t *testing.T) { + entities, cleanup := useItems(t, 3) + defer cleanup() + + entity := entities[0] + + updateData := types.ItemUpdate{ + ID: entity.ID, + Name: entity.Name, + LocationID: entity.Edges.Location.ID, + SerialNumber: fk.RandomString(10), + LabelIDs: nil, + ModelNumber: fk.RandomString(10), + Manufacturer: fk.RandomString(10), + PurchaseTime: time.Now(), + PurchaseFrom: fk.RandomString(10), + PurchasePrice: 300.99, + SoldTime: time.Now(), + SoldTo: fk.RandomString(10), + SoldPrice: 300.99, + SoldNotes: fk.RandomString(10), + Notes: fk.RandomString(10), + } + + updatedEntity, err := tRepos.Items.Update(context.Background(), updateData) + assert.NoError(t, err) + + got, err := tRepos.Items.GetOne(context.Background(), updatedEntity.ID) + assert.NoError(t, err) + + assert.Equal(t, updateData.ID, got.ID) + assert.Equal(t, updateData.Name, got.Name) + assert.Equal(t, updateData.LocationID, got.Edges.Location.ID) + assert.Equal(t, updateData.SerialNumber, got.SerialNumber) + assert.Equal(t, updateData.ModelNumber, got.ModelNumber) + assert.Equal(t, updateData.Manufacturer, got.Manufacturer) + // assert.Equal(t, updateData.PurchaseTime, got.PurchaseTime) + assert.Equal(t, updateData.PurchaseFrom, got.PurchaseFrom) + assert.Equal(t, updateData.PurchasePrice, got.PurchasePrice) + // assert.Equal(t, updateData.SoldTime, got.SoldTime) + assert.Equal(t, updateData.SoldTo, got.SoldTo) + assert.Equal(t, updateData.SoldPrice, got.SoldPrice) + assert.Equal(t, updateData.SoldNotes, got.SoldNotes) + assert.Equal(t, updateData.Notes, got.Notes) +} diff --git a/backend/internal/repo/repo_labels.go b/backend/internal/repo/repo_labels.go index d54c852..bbdcfb1 100644 --- a/backend/internal/repo/repo_labels.go +++ b/backend/internal/repo/repo_labels.go @@ -10,11 +10,11 @@ import ( "github.com/hay-kot/content/backend/internal/types" ) -type EntLabelRepository struct { +type LabelRepository struct { db *ent.Client } -func (r *EntLabelRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Label, error) { +func (r *LabelRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Label, error) { return r.db.Label.Query(). Where(label.ID(ID)). WithGroup(). @@ -22,14 +22,14 @@ func (r *EntLabelRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Label, Only(ctx) } -func (r *EntLabelRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]*ent.Label, error) { +func (r *LabelRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]*ent.Label, error) { return r.db.Label.Query(). Where(label.HasGroupWith(group.ID(groupId))). WithGroup(). All(ctx) } -func (r *EntLabelRepository) Create(ctx context.Context, groupdId uuid.UUID, data types.LabelCreate) (*ent.Label, error) { +func (r *LabelRepository) Create(ctx context.Context, groupdId uuid.UUID, data types.LabelCreate) (*ent.Label, error) { label, err := r.db.Label.Create(). SetName(data.Name). SetDescription(data.Description). @@ -41,7 +41,7 @@ func (r *EntLabelRepository) Create(ctx context.Context, groupdId uuid.UUID, dat return label, err } -func (r *EntLabelRepository) Update(ctx context.Context, data types.LabelUpdate) (*ent.Label, error) { +func (r *LabelRepository) Update(ctx context.Context, data types.LabelUpdate) (*ent.Label, error) { _, err := r.db.Label.UpdateOneID(data.ID). SetName(data.Name). SetDescription(data.Description). @@ -55,6 +55,6 @@ func (r *EntLabelRepository) Update(ctx context.Context, data types.LabelUpdate) return r.Get(ctx, data.ID) } -func (r *EntLabelRepository) Delete(ctx context.Context, id uuid.UUID) error { +func (r *LabelRepository) Delete(ctx context.Context, id uuid.UUID) error { return r.db.Label.DeleteOneID(id).Exec(ctx) } diff --git a/backend/internal/repo/repo_labels_test.go b/backend/internal/repo/repo_labels_test.go new file mode 100644 index 0000000..f647753 --- /dev/null +++ b/backend/internal/repo/repo_labels_test.go @@ -0,0 +1,105 @@ +package repo + +import ( + "context" + "testing" + + "github.com/hay-kot/content/backend/ent" + "github.com/hay-kot/content/backend/internal/types" + "github.com/stretchr/testify/assert" +) + +func labelFactory() types.LabelCreate { + return types.LabelCreate{ + Name: fk.RandomString(10), + Description: fk.RandomString(100), + } +} + +func useLabels(t *testing.T, len int) ([]*ent.Label, func()) { + t.Helper() + + labels := make([]*ent.Label, len) + for i := 0; i < len; i++ { + itm := labelFactory() + + item, err := tRepos.Labels.Create(context.Background(), tGroup.ID, itm) + assert.NoError(t, err) + labels[i] = item + } + + return labels, func() { + for _, item := range labels { + err := tRepos.Labels.Delete(context.Background(), item.ID) + assert.NoError(t, err) + } + } +} + +func TestLabelRepository_Get(t *testing.T) { + labels, cleanup := useLabels(t, 1) + defer cleanup() + label := labels[0] + + // Get by ID + foundLoc, err := tRepos.Labels.Get(context.Background(), label.ID) + assert.NoError(t, err) + assert.Equal(t, label.ID, foundLoc.ID) +} + +func TestLabelRepositoryGetAll(t *testing.T) { + _, cleanup := useLabels(t, 10) + defer cleanup() + + all, err := tRepos.Labels.GetAll(context.Background(), tGroup.ID) + assert.NoError(t, err) + assert.Len(t, all, 10) +} + +func TestLabelRepository_Create(t *testing.T) { + loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) + assert.NoError(t, err) + + // Get by ID + foundLoc, err := tRepos.Labels.Get(context.Background(), loc.ID) + assert.NoError(t, err) + assert.Equal(t, loc.ID, foundLoc.ID) + + err = tRepos.Labels.Delete(context.Background(), loc.ID) + assert.NoError(t, err) +} + +func TestLabelRepository_Update(t *testing.T) { + loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) + assert.NoError(t, err) + + updateData := types.LabelUpdate{ + ID: loc.ID, + Name: fk.RandomString(10), + Description: fk.RandomString(100), + } + + update, err := tRepos.Labels.Update(context.Background(), updateData) + assert.NoError(t, err) + + foundLoc, err := tRepos.Labels.Get(context.Background(), loc.ID) + assert.NoError(t, err) + + assert.Equal(t, update.ID, foundLoc.ID) + assert.Equal(t, update.Name, foundLoc.Name) + assert.Equal(t, update.Description, foundLoc.Description) + + err = tRepos.Labels.Delete(context.Background(), loc.ID) + assert.NoError(t, err) +} + +func TestLabelRepository_Delete(t *testing.T) { + loc, err := tRepos.Labels.Create(context.Background(), tGroup.ID, labelFactory()) + assert.NoError(t, err) + + err = tRepos.Labels.Delete(context.Background(), loc.ID) + assert.NoError(t, err) + + _, err = tRepos.Labels.Get(context.Background(), loc.ID) + assert.Error(t, err) +} diff --git a/backend/internal/repo/repo_locations.go b/backend/internal/repo/repo_locations.go index 88a2724..b23b839 100644 --- a/backend/internal/repo/repo_locations.go +++ b/backend/internal/repo/repo_locations.go @@ -9,7 +9,7 @@ import ( "github.com/hay-kot/content/backend/internal/types" ) -type EntLocationRepository struct { +type LocationRepository struct { db *ent.Client } @@ -19,8 +19,8 @@ type LocationWithCount struct { } // GetALlWithCount returns all locations with item count field populated -func (r *EntLocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]LocationWithCount, error) { - query := ` +func (r *LocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]LocationWithCount, error) { + query := `--sql SELECT id, name, @@ -61,7 +61,7 @@ func (r *EntLocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ( return list, err } -func (r *EntLocationRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Location, error) { +func (r *LocationRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Location, error) { return r.db.Location.Query(). Where(location.ID(ID)). WithGroup(). @@ -71,7 +71,7 @@ func (r *EntLocationRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Loc Only(ctx) } -func (r *EntLocationRepository) Create(ctx context.Context, groupdId uuid.UUID, data types.LocationCreate) (*ent.Location, error) { +func (r *LocationRepository) Create(ctx context.Context, groupdId uuid.UUID, data types.LocationCreate) (*ent.Location, error) { location, err := r.db.Location.Create(). SetName(data.Name). SetDescription(data.Description). @@ -82,7 +82,7 @@ func (r *EntLocationRepository) Create(ctx context.Context, groupdId uuid.UUID, return location, err } -func (r *EntLocationRepository) Update(ctx context.Context, data types.LocationUpdate) (*ent.Location, error) { +func (r *LocationRepository) Update(ctx context.Context, data types.LocationUpdate) (*ent.Location, error) { _, err := r.db.Location.UpdateOneID(data.ID). SetName(data.Name). SetDescription(data.Description). @@ -95,6 +95,6 @@ func (r *EntLocationRepository) Update(ctx context.Context, data types.LocationU return r.Get(ctx, data.ID) } -func (r *EntLocationRepository) Delete(ctx context.Context, id uuid.UUID) error { +func (r *LocationRepository) Delete(ctx context.Context, id uuid.UUID) error { return r.db.Location.DeleteOneID(id).Exec(ctx) } diff --git a/backend/internal/repo/repo_locations_test.go b/backend/internal/repo/repo_locations_test.go index f2818ef..e8d2f54 100644 --- a/backend/internal/repo/repo_locations_test.go +++ b/backend/internal/repo/repo_locations_test.go @@ -5,12 +5,9 @@ import ( "testing" "github.com/hay-kot/content/backend/internal/types" - "github.com/hay-kot/content/backend/pkgs/faker" "github.com/stretchr/testify/assert" ) -var fk = faker.NewFaker() - func locationFactory() types.LocationCreate { return types.LocationCreate{ Name: fk.RandomString(10), @@ -18,28 +15,28 @@ func locationFactory() types.LocationCreate { } } -func Test_Locations_Get(t *testing.T) { - loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory()) +func TestLocationRepository_Get(t *testing.T) { + loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) assert.NoError(t, err) // Get by ID - foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID) + foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) assert.NoError(t, err) assert.Equal(t, loc.ID, foundLoc.ID) - err = testRepos.Locations.Delete(context.Background(), loc.ID) + err = tRepos.Locations.Delete(context.Background(), loc.ID) assert.NoError(t, err) } -func Test_LocationsGetAllWithCount(t *testing.T) { +func TestLocationRepositoryGetAllWithCount(t *testing.T) { ctx := context.Background() - result, err := testRepos.Locations.Create(ctx, testGroup.ID, types.LocationCreate{ + result, err := tRepos.Locations.Create(ctx, tGroup.ID, types.LocationCreate{ Name: fk.RandomString(10), Description: fk.RandomString(100), }) assert.NoError(t, err) - _, err = testRepos.Items.Create(ctx, testGroup.ID, types.ItemCreate{ + _, err = tRepos.Items.Create(ctx, tGroup.ID, types.ItemCreate{ Name: fk.RandomString(10), Description: fk.RandomString(100), LocationID: result.ID, @@ -47,7 +44,7 @@ func Test_LocationsGetAllWithCount(t *testing.T) { assert.NoError(t, err) - results, err := testRepos.Locations.GetAll(context.Background(), testGroup.ID) + results, err := tRepos.Locations.GetAll(context.Background(), tGroup.ID) assert.NoError(t, err) for _, loc := range results { @@ -58,21 +55,21 @@ func Test_LocationsGetAllWithCount(t *testing.T) { } -func Test_Locations_Create(t *testing.T) { - loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory()) +func TestLocationRepository_Create(t *testing.T) { + loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) assert.NoError(t, err) // Get by ID - foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID) + foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) assert.NoError(t, err) assert.Equal(t, loc.ID, foundLoc.ID) - err = testRepos.Locations.Delete(context.Background(), loc.ID) + err = tRepos.Locations.Delete(context.Background(), loc.ID) assert.NoError(t, err) } -func Test_Locations_Update(t *testing.T) { - loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory()) +func TestLocationRepository_Update(t *testing.T) { + loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) assert.NoError(t, err) updateData := types.LocationUpdate{ @@ -81,27 +78,27 @@ func Test_Locations_Update(t *testing.T) { Description: fk.RandomString(100), } - update, err := testRepos.Locations.Update(context.Background(), updateData) + update, err := tRepos.Locations.Update(context.Background(), updateData) assert.NoError(t, err) - foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID) + foundLoc, err := tRepos.Locations.Get(context.Background(), loc.ID) assert.NoError(t, err) assert.Equal(t, update.ID, foundLoc.ID) assert.Equal(t, update.Name, foundLoc.Name) assert.Equal(t, update.Description, foundLoc.Description) - err = testRepos.Locations.Delete(context.Background(), loc.ID) + err = tRepos.Locations.Delete(context.Background(), loc.ID) assert.NoError(t, err) } -func Test_Locations_Delete(t *testing.T) { - loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory()) +func TestLocationRepository_Delete(t *testing.T) { + loc, err := tRepos.Locations.Create(context.Background(), tGroup.ID, locationFactory()) assert.NoError(t, err) - err = testRepos.Locations.Delete(context.Background(), loc.ID) + err = tRepos.Locations.Delete(context.Background(), loc.ID) assert.NoError(t, err) - _, err = testRepos.Locations.Get(context.Background(), loc.ID) + _, err = tRepos.Locations.Get(context.Background(), loc.ID) assert.Error(t, err) } diff --git a/backend/internal/repo/repo_tokens.go b/backend/internal/repo/repo_tokens.go index d40360f..13d5c72 100644 --- a/backend/internal/repo/repo_tokens.go +++ b/backend/internal/repo/repo_tokens.go @@ -9,12 +9,12 @@ import ( "github.com/hay-kot/content/backend/internal/types" ) -type EntTokenRepository struct { +type TokenRepository struct { db *ent.Client } // GetUserFromToken get's a user from a token -func (r *EntTokenRepository) GetUserFromToken(ctx context.Context, token []byte) (*ent.User, error) { +func (r *TokenRepository) GetUserFromToken(ctx context.Context, token []byte) (*ent.User, error) { user, err := r.db.AuthTokens.Query(). Where(authtokens.Token(token)). Where(authtokens.ExpiresAtGTE(time.Now())). @@ -31,7 +31,7 @@ func (r *EntTokenRepository) GetUserFromToken(ctx context.Context, token []byte) } // Creates a token for a user -func (r *EntTokenRepository) CreateToken(ctx context.Context, createToken types.UserAuthTokenCreate) (types.UserAuthToken, error) { +func (r *TokenRepository) CreateToken(ctx context.Context, createToken types.UserAuthTokenCreate) (types.UserAuthToken, error) { tokenOut := types.UserAuthToken{} dbToken, err := r.db.AuthTokens.Create(). @@ -53,13 +53,13 @@ func (r *EntTokenRepository) CreateToken(ctx context.Context, createToken types. } // DeleteToken remove a single token from the database - equivalent to revoke or logout -func (r *EntTokenRepository) DeleteToken(ctx context.Context, token []byte) error { +func (r *TokenRepository) DeleteToken(ctx context.Context, token []byte) error { _, err := r.db.AuthTokens.Delete().Where(authtokens.Token(token)).Exec(ctx) return err } // PurgeExpiredTokens removes all expired tokens from the database -func (r *EntTokenRepository) PurgeExpiredTokens(ctx context.Context) (int, error) { +func (r *TokenRepository) PurgeExpiredTokens(ctx context.Context) (int, error) { tokensDeleted, err := r.db.AuthTokens.Delete().Where(authtokens.ExpiresAtLTE(time.Now())).Exec(ctx) if err != nil { @@ -69,7 +69,7 @@ func (r *EntTokenRepository) PurgeExpiredTokens(ctx context.Context) (int, error return tokensDeleted, nil } -func (r *EntTokenRepository) DeleteAll(ctx context.Context) (int, error) { +func (r *TokenRepository) DeleteAll(ctx context.Context) (int, error) { amount, err := r.db.AuthTokens.Delete().Exec(ctx) return amount, err } diff --git a/backend/internal/repo/repo_tokens_test.go b/backend/internal/repo/repo_tokens_test.go index be737ef..2ab453a 100644 --- a/backend/internal/repo/repo_tokens_test.go +++ b/backend/internal/repo/repo_tokens_test.go @@ -10,46 +10,69 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_EntAuthTokenRepo_CreateToken(t *testing.T) { - assert := assert.New(t) +func TestAuthTokenRepo_CreateToken(t *testing.T) { + asrt := assert.New(t) ctx := context.Background() + user := userFactory() - user := UserFactory() - - userOut, err := testRepos.Users.Create(ctx, user) - assert.NoError(err) + userOut, err := tRepos.Users.Create(ctx, user) + asrt.NoError(err) expiresAt := time.Now().Add(time.Hour) generatedToken := hasher.GenerateToken() - token, err := testRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ + token, err := tRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ TokenHash: generatedToken.Hash, ExpiresAt: expiresAt, UserID: userOut.ID, }) - assert.NoError(err) - assert.Equal(userOut.ID, token.UserID) - assert.Equal(expiresAt, token.ExpiresAt) + asrt.NoError(err) + asrt.Equal(userOut.ID, token.UserID) + asrt.Equal(expiresAt, token.ExpiresAt) // Cleanup - assert.NoError(testRepos.Users.Delete(ctx, userOut.ID)) - _, err = testRepos.AuthTokens.DeleteAll(ctx) - assert.NoError(err) + asrt.NoError(tRepos.Users.Delete(ctx, userOut.ID)) + _, err = tRepos.AuthTokens.DeleteAll(ctx) + asrt.NoError(err) } -func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) { +func TestAuthTokenRepo_DeleteToken(t *testing.T) { + asrt := assert.New(t) + ctx := context.Background() + user := userFactory() + + userOut, err := tRepos.Users.Create(ctx, user) + asrt.NoError(err) + + expiresAt := time.Now().Add(time.Hour) + + generatedToken := hasher.GenerateToken() + + _, err = tRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ + TokenHash: generatedToken.Hash, + ExpiresAt: expiresAt, + UserID: userOut.ID, + }) + asrt.NoError(err) + + // Delete token + err = tRepos.AuthTokens.DeleteToken(ctx, []byte(generatedToken.Raw)) + asrt.NoError(err) +} + +func TestAuthTokenRepo_GetUserByToken(t *testing.T) { assert := assert.New(t) ctx := context.Background() - user := UserFactory() - userOut, _ := testRepos.Users.Create(ctx, user) + user := userFactory() + userOut, _ := tRepos.Users.Create(ctx, user) expiresAt := time.Now().Add(time.Hour) generatedToken := hasher.GenerateToken() - token, err := testRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ + token, err := tRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ TokenHash: generatedToken.Hash, ExpiresAt: expiresAt, UserID: userOut.ID, @@ -58,7 +81,7 @@ func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) { assert.NoError(err) // Get User from token - foundUser, err := testRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) + foundUser, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) assert.NoError(err) assert.Equal(userOut.ID, foundUser.ID) @@ -66,17 +89,17 @@ func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) { assert.Equal(userOut.Email, foundUser.Email) // Cleanup - assert.NoError(testRepos.Users.Delete(ctx, userOut.ID)) - _, err = testRepos.AuthTokens.DeleteAll(ctx) + assert.NoError(tRepos.Users.Delete(ctx, userOut.ID)) + _, err = tRepos.AuthTokens.DeleteAll(ctx) assert.NoError(err) } -func Test_EntAuthTokenRepo_PurgeExpiredTokens(t *testing.T) { +func TestAuthTokenRepo_PurgeExpiredTokens(t *testing.T) { assert := assert.New(t) ctx := context.Background() - user := UserFactory() - userOut, _ := testRepos.Users.Create(ctx, user) + user := userFactory() + userOut, _ := tRepos.Users.Create(ctx, user) createdTokens := []types.UserAuthToken{} @@ -84,7 +107,7 @@ func Test_EntAuthTokenRepo_PurgeExpiredTokens(t *testing.T) { expiresAt := time.Now() generatedToken := hasher.GenerateToken() - createdToken, err := testRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ + createdToken, err := tRepos.AuthTokens.CreateToken(ctx, types.UserAuthTokenCreate{ TokenHash: generatedToken.Hash, ExpiresAt: expiresAt, UserID: userOut.ID, @@ -98,19 +121,19 @@ func Test_EntAuthTokenRepo_PurgeExpiredTokens(t *testing.T) { } // Purge expired tokens - tokensDeleted, err := testRepos.AuthTokens.PurgeExpiredTokens(ctx) + tokensDeleted, err := tRepos.AuthTokens.PurgeExpiredTokens(ctx) assert.NoError(err) assert.Equal(5, tokensDeleted) // Check if tokens are deleted for _, token := range createdTokens { - _, err := testRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) + _, err := tRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash) assert.Error(err) } // Cleanup - assert.NoError(testRepos.Users.Delete(ctx, userOut.ID)) - _, err = testRepos.AuthTokens.DeleteAll(ctx) + assert.NoError(tRepos.Users.Delete(ctx, userOut.ID)) + _, err = tRepos.AuthTokens.DeleteAll(ctx) assert.NoError(err) } diff --git a/backend/internal/repo/repo_users.go b/backend/internal/repo/repo_users.go index 66cd7ac..f4acaf4 100644 --- a/backend/internal/repo/repo_users.go +++ b/backend/internal/repo/repo_users.go @@ -9,29 +9,29 @@ import ( "github.com/hay-kot/content/backend/internal/types" ) -type EntUserRepository struct { +type UserRepository struct { db *ent.Client } -func (e *EntUserRepository) GetOneId(ctx context.Context, id uuid.UUID) (*ent.User, error) { +func (e *UserRepository) GetOneId(ctx context.Context, id uuid.UUID) (*ent.User, error) { return e.db.User.Query(). Where(user.ID(id)). WithGroup(). Only(ctx) } -func (e *EntUserRepository) GetOneEmail(ctx context.Context, email string) (*ent.User, error) { +func (e *UserRepository) GetOneEmail(ctx context.Context, email string) (*ent.User, error) { return e.db.User.Query(). Where(user.Email(email)). WithGroup(). Only(ctx) } -func (e *EntUserRepository) GetAll(ctx context.Context) ([]*ent.User, error) { +func (e *UserRepository) GetAll(ctx context.Context) ([]*ent.User, error) { return e.db.User.Query().WithGroup().All(ctx) } -func (e *EntUserRepository) Create(ctx context.Context, usr types.UserCreate) (*ent.User, error) { +func (e *UserRepository) Create(ctx context.Context, usr types.UserCreate) (*ent.User, error) { err := usr.Validate() if err != nil { return &ent.User{}, err @@ -52,41 +52,27 @@ func (e *EntUserRepository) Create(ctx context.Context, usr types.UserCreate) (* return e.GetOneId(ctx, entUser.ID) } -func (e *EntUserRepository) Update(ctx context.Context, ID uuid.UUID, data types.UserUpdate) error { - bldr := e.db.User.Update().Where(user.ID(ID)) +func (e *UserRepository) Update(ctx context.Context, ID uuid.UUID, data types.UserUpdate) error { + q := e.db.User.Update(). + Where(user.ID(ID)). + SetName(data.Name). + SetEmail(data.Email) - if data.Name != nil { - bldr = bldr.SetName(*data.Name) - } - - if data.Email != nil { - bldr = bldr.SetEmail(*data.Email) - } - - // TODO: FUTURE - // if data.Password != nil { - // bldr = bldr.SetPassword(*data.Password) - // } - - // if data.IsSuperuser != nil { - // bldr = bldr.SetIsSuperuser(*data.IsSuperuser) - // } - - _, err := bldr.Save(ctx) + _, err := q.Save(ctx) return err } -func (e *EntUserRepository) Delete(ctx context.Context, id uuid.UUID) error { +func (e *UserRepository) Delete(ctx context.Context, id uuid.UUID) error { _, err := e.db.User.Delete().Where(user.ID(id)).Exec(ctx) return err } -func (e *EntUserRepository) DeleteAll(ctx context.Context) error { +func (e *UserRepository) DeleteAll(ctx context.Context) error { _, err := e.db.User.Delete().Exec(ctx) return err } -func (e *EntUserRepository) GetSuperusers(ctx context.Context) ([]*ent.User, error) { +func (e *UserRepository) GetSuperusers(ctx context.Context) ([]*ent.User, error) { users, err := e.db.User.Query().Where(user.IsSuperuser(true)).All(ctx) if err != nil { diff --git a/backend/internal/repo/repo_users_test.go b/backend/internal/repo/repo_users_test.go index 4f63b61..98f115d 100644 --- a/backend/internal/repo/repo_users_test.go +++ b/backend/internal/repo/repo_users_test.go @@ -7,31 +7,29 @@ import ( "github.com/hay-kot/content/backend/ent" "github.com/hay-kot/content/backend/internal/types" - "github.com/hay-kot/content/backend/pkgs/faker" "github.com/stretchr/testify/assert" ) -func UserFactory() types.UserCreate { - f := faker.NewFaker() +func userFactory() types.UserCreate { return types.UserCreate{ - Name: f.RandomString(10), - Email: f.RandomEmail(), - Password: f.RandomString(10), - IsSuperuser: f.RandomBool(), - GroupID: testGroup.ID, + Name: fk.RandomString(10), + Email: fk.RandomEmail(), + Password: fk.RandomString(10), + IsSuperuser: fk.RandomBool(), + GroupID: tGroup.ID, } } -func Test_EntUserRepo_GetOneEmail(t *testing.T) { +func TestUserRepo_GetOneEmail(t *testing.T) { assert := assert.New(t) - user := UserFactory() + user := userFactory() ctx := context.Background() - _, err := testRepos.Users.Create(ctx, user) + _, err := tRepos.Users.Create(ctx, user) assert.NoError(err) - foundUser, err := testRepos.Users.GetOneEmail(ctx, user.Email) + foundUser, err := tRepos.Users.GetOneEmail(ctx, user.Email) assert.NotNil(foundUser) assert.Nil(err) @@ -39,17 +37,17 @@ func Test_EntUserRepo_GetOneEmail(t *testing.T) { assert.Equal(user.Name, foundUser.Name) // Cleanup - err = testRepos.Users.DeleteAll(ctx) + err = tRepos.Users.DeleteAll(ctx) assert.NoError(err) } -func Test_EntUserRepo_GetOneId(t *testing.T) { +func TestUserRepo_GetOneId(t *testing.T) { assert := assert.New(t) - user := UserFactory() + user := userFactory() ctx := context.Background() - userOut, _ := testRepos.Users.Create(ctx, user) - foundUser, err := testRepos.Users.GetOneId(ctx, userOut.ID) + userOut, _ := tRepos.Users.Create(ctx, user) + foundUser, err := tRepos.Users.GetOneId(ctx, userOut.ID) assert.NotNil(foundUser) assert.Nil(err) @@ -57,17 +55,17 @@ func Test_EntUserRepo_GetOneId(t *testing.T) { assert.Equal(user.Name, foundUser.Name) // Cleanup - err = testRepos.Users.DeleteAll(ctx) + err = tRepos.Users.DeleteAll(ctx) assert.NoError(err) } -func Test_EntUserRepo_GetAll(t *testing.T) { +func TestUserRepo_GetAll(t *testing.T) { // Setup toCreate := []types.UserCreate{ - UserFactory(), - UserFactory(), - UserFactory(), - UserFactory(), + userFactory(), + userFactory(), + userFactory(), + userFactory(), } ctx := context.Background() @@ -75,12 +73,12 @@ func Test_EntUserRepo_GetAll(t *testing.T) { created := []*ent.User{} for _, usr := range toCreate { - usrOut, _ := testRepos.Users.Create(ctx, usr) + usrOut, _ := tRepos.Users.Create(ctx, usr) created = append(created, usrOut) } // Validate - allUsers, err := testRepos.Users.GetAll(ctx) + allUsers, err := tRepos.Users.GetAll(ctx) assert.NoError(t, err) assert.Equal(t, len(created), len(allUsers)) @@ -98,48 +96,64 @@ func Test_EntUserRepo_GetAll(t *testing.T) { } for _, usr := range created { - _ = testRepos.Users.Delete(ctx, usr.ID) + _ = tRepos.Users.Delete(ctx, usr.ID) } // Cleanup - err = testRepos.Users.DeleteAll(ctx) + err = tRepos.Users.DeleteAll(ctx) assert.NoError(t, err) } -func Test_EntUserRepo_Update(t *testing.T) { - t.Skip() +func TestUserRepo_Update(t *testing.T) { + user, err := tRepos.Users.Create(context.Background(), userFactory()) + assert.NoError(t, err) + + updateData := types.UserUpdate{ + Name: fk.RandomString(10), + Email: fk.RandomEmail(), + } + + // Update + err = tRepos.Users.Update(context.Background(), user.ID, updateData) + assert.NoError(t, err) + + // Validate + updated, err := tRepos.Users.GetOneId(context.Background(), user.ID) + assert.NoError(t, err) + assert.NotEqual(t, user.Name, updated.Name) + assert.NotEqual(t, user.Email, updated.Email) } -func Test_EntUserRepo_Delete(t *testing.T) { +func TestUserRepo_Delete(t *testing.T) { // Create 10 Users for i := 0; i < 10; i++ { - user := UserFactory() + user := userFactory() ctx := context.Background() - _, _ = testRepos.Users.Create(ctx, user) + _, _ = tRepos.Users.Create(ctx, user) } // Delete all ctx := context.Background() - allUsers, _ := testRepos.Users.GetAll(ctx) + allUsers, _ := tRepos.Users.GetAll(ctx) assert.Greater(t, len(allUsers), 0) - err := testRepos.Users.DeleteAll(ctx) + err := tRepos.Users.DeleteAll(ctx) assert.NoError(t, err) - allUsers, _ = testRepos.Users.GetAll(ctx) + allUsers, _ = tRepos.Users.GetAll(ctx) assert.Equal(t, len(allUsers), 0) } -func Test_EntUserRepo_GetSuperusers(t *testing.T) { +func TestUserRepo_GetSuperusers(t *testing.T) { // Create 10 Users superuser := 0 users := 0 for i := 0; i < 10; i++ { - user := UserFactory() + user := userFactory() ctx := context.Background() - _, _ = testRepos.Users.Create(ctx, user) + _, _ = tRepos.Users.Create(ctx, user) if user.IsSuperuser { superuser++ @@ -151,7 +165,7 @@ func Test_EntUserRepo_GetSuperusers(t *testing.T) { // Delete all ctx := context.Background() - superUsers, err := testRepos.Users.GetSuperusers(ctx) + superUsers, err := tRepos.Users.GetSuperusers(ctx) assert.NoError(t, err) for _, usr := range superUsers { @@ -159,6 +173,6 @@ func Test_EntUserRepo_GetSuperusers(t *testing.T) { } // Cleanup - err = testRepos.Users.DeleteAll(ctx) + err = tRepos.Users.DeleteAll(ctx) assert.NoError(t, err) } diff --git a/backend/internal/repo/repos_all.go b/backend/internal/repo/repos_all.go index 906d83c..3542728 100644 --- a/backend/internal/repo/repos_all.go +++ b/backend/internal/repo/repos_all.go @@ -4,21 +4,21 @@ import "github.com/hay-kot/content/backend/ent" // AllRepos is a container for all the repository interfaces type AllRepos struct { - Users *EntUserRepository - AuthTokens *EntTokenRepository - Groups *EntGroupRepository - Locations *EntLocationRepository - Labels *EntLabelRepository + Users *UserRepository + AuthTokens *TokenRepository + Groups *GroupRepository + Locations *LocationRepository + Labels *LabelRepository Items *ItemsRepository } func EntAllRepos(db *ent.Client) *AllRepos { return &AllRepos{ - Users: &EntUserRepository{db}, - AuthTokens: &EntTokenRepository{db}, - Groups: &EntGroupRepository{db}, - Locations: &EntLocationRepository{db}, - Labels: &EntLabelRepository{db}, + Users: &UserRepository{db}, + AuthTokens: &TokenRepository{db}, + Groups: &GroupRepository{db}, + Locations: &LocationRepository{db}, + Labels: &LabelRepository{db}, Items: &ItemsRepository{db}, } } diff --git a/backend/internal/types/users_types.go b/backend/internal/types/users_types.go index 2f4053b..0a10ec4 100644 --- a/backend/internal/types/users_types.go +++ b/backend/internal/types/users_types.go @@ -41,8 +41,8 @@ func (u UserCreate) Validate() error { } type UserUpdate struct { - Name *string `json:"name"` - Email *string `json:"email"` + Name string `json:"name"` + Email string `json:"email"` } type UserRegistration struct { diff --git a/backend/pkgs/server/server_test.go b/backend/pkgs/server/server_test.go index b69b3eb..669182b 100644 --- a/backend/pkgs/server/server_test.go +++ b/backend/pkgs/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "net/http" + "sync/atomic" "testing" "time" @@ -72,14 +73,14 @@ func Test_GracefulServerShutdownWithWorkers(t *testing.T) { } func Test_GracefulServerShutdownWithRequests(t *testing.T) { - isFinished := false + var isFinished atomic.Bool router := http.NewServeMux() // add long running handler func router.HandleFunc("/test", func(rw http.ResponseWriter, r *http.Request) { time.Sleep(time.Second * 3) - isFinished = true + isFinished.Store(true) }) svr := testServer(t, router) @@ -94,5 +95,5 @@ func Test_GracefulServerShutdownWithRequests(t *testing.T) { err := svr.Shutdown("test") assert.NoError(t, err) - assert.True(t, isFinished) + assert.True(t, isFinished.Load()) }