setup location repository

This commit is contained in:
Hayden 2022-08-30 19:21:18 -08:00
parent 630fe83de5
commit 9583847f94
11 changed files with 267 additions and 25 deletions

View file

@ -12,8 +12,23 @@ import (
_ "github.com/mattn/go-sqlite3"
)
var testEntClient *ent.Client
var testRepos *AllRepos
var (
testEntClient *ent.Client
testRepos *AllRepos
testUser *ent.User
testGroup *ent.Group
)
func bootstrap() {
ctx := context.Background()
testGroup, _ = testRepos.Groups.Create(ctx, "test-group")
testUser, _ = testRepos.Users.Create(ctx, UserFactory())
if testGroup == nil || testUser == nil {
log.Fatal("Failed to bootstrap test data")
}
}
func TestMain(m *testing.M) {
rand.Seed(int64(time.Now().Unix()))
@ -29,10 +44,9 @@ func TestMain(m *testing.M) {
testEntClient = client
testRepos = EntAllRepos(testEntClient)
defer client.Close()
m.Run()
bootstrap()
os.Exit(m.Run())
}

View file

@ -12,18 +12,9 @@ type EntGroupRepository struct {
}
func (r *EntGroupRepository) Create(ctx context.Context, name string) (*ent.Group, error) {
dbGroup, err := r.db.Group.Create().SetName(name).Save(ctx)
if err != nil {
return dbGroup, err
}
return dbGroup, nil
return r.db.Group.Create().SetName(name).Save(ctx)
}
func (r *EntGroupRepository) GetOneId(ctx context.Context, id uuid.UUID) (*ent.Group, error) {
dbGroup, err := r.db.Group.Get(ctx, id)
if err != nil {
return dbGroup, err
}
return dbGroup, nil
return r.db.Group.Get(ctx, id)
}

View file

@ -0,0 +1,20 @@
package repo
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
func Test_Group_Create(t *testing.T) {
g, err := testRepos.Groups.Create(context.Background(), "test")
assert.NoError(t, err)
assert.Equal(t, "test", g.Name)
// Get by ID
foundGroup, err := testRepos.Groups.GetOneId(context.Background(), g.ID)
assert.NoError(t, err)
assert.Equal(t, g.ID, foundGroup.ID)
}

View file

@ -0,0 +1,47 @@
package repo
import (
"context"
"github.com/google/uuid"
"github.com/hay-kot/content/backend/ent"
"github.com/hay-kot/content/backend/ent/group"
"github.com/hay-kot/content/backend/ent/location"
"github.com/hay-kot/content/backend/internal/types"
)
type EntLocationRepository struct {
db *ent.Client
}
func (r *EntLocationRepository) Get(ctx context.Context, ID uuid.UUID) (*ent.Location, error) {
return r.db.Location.Get(ctx, ID)
}
func (r *EntLocationRepository) GetAll(ctx context.Context, groupId uuid.UUID) ([]*ent.Location, error) {
return r.db.Location.Query().
Where(location.HasGroupWith(group.ID(groupId))).
All(ctx)
}
func (r *EntLocationRepository) Create(ctx context.Context, groupdId uuid.UUID, data types.LocationCreate) (*ent.Location, error) {
location, err := r.db.Location.Create().
SetName(data.Name).
SetDescription(data.Description).
SetGroupID(groupdId).
Save(ctx)
location.Edges.Group = &ent.Group{ID: groupdId} // bootstrap group ID
return location, err
}
func (r *EntLocationRepository) Update(ctx context.Context, data types.LocationUpdate) (*ent.Location, error) {
return r.db.Location.UpdateOneID(data.ID).
SetName(data.Name).
SetDescription(data.Description).
Save(ctx)
}
func (r *EntLocationRepository) Delete(ctx context.Context, id uuid.UUID) error {
return r.db.Location.DeleteOneID(id).Exec(ctx)
}

View file

@ -0,0 +1,100 @@
package repo
import (
"context"
"testing"
"github.com/hay-kot/content/backend/ent"
"github.com/hay-kot/content/backend/internal/types"
"github.com/hay-kot/content/backend/pkgs/faker"
"github.com/stretchr/testify/assert"
)
var fk = faker.NewFaker()
func locationFactory() types.LocationCreate {
return types.LocationCreate{
Name: fk.RandomString(10),
Description: fk.RandomString(100),
}
}
func Test_Locations_Get(t *testing.T) {
loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory())
assert.NoError(t, err)
// Get by ID
foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err)
assert.Equal(t, loc.ID, foundLoc.ID)
testRepos.Locations.Delete(context.Background(), loc.ID)
}
func Test_Locations_GetAll(t *testing.T) {
created := make([]*ent.Location, 6)
for i := 0; i < 6; i++ {
result, err := testRepos.Locations.Create(context.Background(), testGroup.ID, types.LocationCreate{
Name: fk.RandomString(10),
Description: fk.RandomString(100),
})
assert.NoError(t, err)
created[i] = result
}
locations, err := testRepos.Locations.GetAll(context.Background(), testGroup.ID)
assert.NoError(t, err)
assert.Equal(t, 6, len(locations))
for _, loc := range created {
testRepos.Locations.Delete(context.Background(), loc.ID)
}
}
func Test_Locations_Create(t *testing.T) {
loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory())
assert.NoError(t, err)
// Get by ID
foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err)
assert.Equal(t, loc.ID, foundLoc.ID)
testRepos.Locations.Delete(context.Background(), loc.ID)
}
func Test_Locations_Update(t *testing.T) {
loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory())
assert.NoError(t, err)
updateData := types.LocationUpdate{
ID: loc.ID,
Name: fk.RandomString(10),
Description: fk.RandomString(100),
}
update, err := testRepos.Locations.Update(context.Background(), updateData)
assert.NoError(t, err)
foundLoc, err := testRepos.Locations.Get(context.Background(), loc.ID)
assert.NoError(t, err)
assert.Equal(t, update.ID, foundLoc.ID)
assert.Equal(t, update.Name, foundLoc.Name)
assert.Equal(t, update.Description, foundLoc.Description)
testRepos.Locations.Delete(context.Background(), loc.ID)
}
func Test_Locations_Delete(t *testing.T) {
loc, err := testRepos.Locations.Create(context.Background(), testGroup.ID, locationFactory())
assert.NoError(t, err)
err = testRepos.Locations.Delete(context.Background(), loc.ID)
assert.NoError(t, err)
_, err = testRepos.Locations.Get(context.Background(), loc.ID)
assert.Error(t, err)
}

View file

@ -16,7 +16,8 @@ func Test_EntAuthTokenRepo_CreateToken(t *testing.T) {
user := UserFactory()
userOut, _ := testRepos.Users.Create(ctx, user)
userOut, err := testRepos.Users.Create(ctx, user)
assert.NoError(err)
expiresAt := time.Now().Add(time.Hour)
@ -33,8 +34,8 @@ func Test_EntAuthTokenRepo_CreateToken(t *testing.T) {
assert.Equal(expiresAt, token.ExpiresAt)
// Cleanup
err = testRepos.Users.Delete(ctx, userOut.ID)
_, err = testRepos.AuthTokens.DeleteAll(ctx)
testRepos.Users.Delete(ctx, userOut.ID)
testRepos.AuthTokens.DeleteAll(ctx)
}
func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) {
@ -53,6 +54,8 @@ func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) {
UserID: userOut.ID,
})
assert.NoError(err)
// Get User from token
foundUser, err := testRepos.AuthTokens.GetUserFromToken(ctx, token.TokenHash)
@ -62,8 +65,8 @@ func Test_EntAuthTokenRepo_GetUserByToken(t *testing.T) {
assert.Equal(userOut.Email, foundUser.Email)
// Cleanup
err = testRepos.Users.Delete(ctx, userOut.ID)
_, err = testRepos.AuthTokens.DeleteAll(ctx)
testRepos.Users.Delete(ctx, userOut.ID)
testRepos.AuthTokens.DeleteAll(ctx)
}
func Test_EntAuthTokenRepo_PurgeExpiredTokens(t *testing.T) {
@ -105,6 +108,6 @@ func Test_EntAuthTokenRepo_PurgeExpiredTokens(t *testing.T) {
}
// Cleanup
err = testRepos.Users.Delete(ctx, userOut.ID)
_, err = testRepos.AuthTokens.DeleteAll(ctx)
testRepos.Users.Delete(ctx, userOut.ID)
testRepos.AuthTokens.DeleteAll(ctx)
}

View file

@ -34,7 +34,7 @@ func (e *EntUserRepository) GetOneEmail(ctx context.Context, email string) (*ent
}
func (e *EntUserRepository) GetAll(ctx context.Context) ([]*ent.User, error) {
users, err := e.db.User.Query().All(ctx)
users, err := e.db.User.Query().WithGroup().All(ctx)
if err != nil {
return nil, err

View file

@ -13,11 +13,13 @@ import (
func UserFactory() types.UserCreate {
f := faker.NewFaker()
return types.UserCreate{
Name: f.RandomString(10),
Email: f.RandomEmail(),
Password: f.RandomString(10),
IsSuperuser: f.RandomBool(),
GroupID: testGroup.ID,
}
}
@ -77,12 +79,19 @@ func Test_EntUserRepo_GetAll(t *testing.T) {
// Validate
allUsers, err := testRepos.Users.GetAll(ctx)
assert.Nil(t, err)
assert.NoError(t, err)
assert.Equal(t, len(created), len(allUsers))
for _, usr := range created {
fmt.Printf("%+v\n", usr)
assert.Contains(t, allUsers, usr)
for _, usr2 := range allUsers {
if usr.ID == usr2.ID {
assert.Equal(t, usr.Email, usr2.Email)
// Check groups are loaded
assert.NotNil(t, usr2.Edges.Group)
}
}
}
for _, usr := range created {

View file

@ -7,6 +7,7 @@ type AllRepos struct {
Users *EntUserRepository
AuthTokens *EntTokenRepository
Groups *EntGroupRepository
Locations *EntLocationRepository
}
func EntAllRepos(db *ent.Client) *AllRepos {
@ -14,5 +15,6 @@ func EntAllRepos(db *ent.Client) *AllRepos {
Users: &EntUserRepository{db},
AuthTokens: &EntTokenRepository{db},
Groups: &EntGroupRepository{db},
Locations: &EntLocationRepository{db},
}
}

View file

@ -0,0 +1,29 @@
package services
import (
"context"
"github.com/google/uuid"
"github.com/hay-kot/content/backend/ent"
"github.com/hay-kot/content/backend/internal/repo"
"github.com/hay-kot/content/backend/internal/types"
)
type LocationService struct {
repos *repo.AllRepos
}
func ToLocationOut(location *ent.Location, err error) (*types.LocationOut, error) {
return &types.LocationOut{
ID: location.ID,
GroupID: location.Edges.Group.ID,
Name: location.Name,
Description: location.Description,
CreatedAt: location.CreatedAt,
UpdatedAt: location.UpdatedAt,
}, err
}
func (svc *LocationService) GetAll(ctx context.Context, groupId uuid.UUID) ([]*types.LocationOut, error) {
panic("not implemented")
}

View file

@ -0,0 +1,27 @@
package types
import (
"time"
"github.com/google/uuid"
)
type LocationCreate struct {
Name string `json:"name"`
Description string `json:"description"`
}
type LocationUpdate struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Description string `json:"description"`
}
type LocationOut struct {
ID uuid.UUID `json:"id"`
GroupID uuid.UUID `json:"groupId"`
Name string `json:"name"`
Description string `json:"description"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}