homebox/backend/internal/data/repo/repo_group.go
2024-05-24 19:46:32 -05:00

315 lines
8.8 KiB
Go

package repo
import (
"context"
"strings"
"time"
"entgo.io/ent/dialect/sql"
"github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/data/ent"
"github.com/hay-kot/homebox/backend/internal/data/ent/group"
"github.com/hay-kot/homebox/backend/internal/data/ent/groupinvitationtoken"
"github.com/hay-kot/homebox/backend/internal/data/ent/item"
"github.com/hay-kot/homebox/backend/internal/data/ent/label"
"github.com/hay-kot/homebox/backend/internal/data/ent/location"
)
type GroupRepository struct {
db *ent.Client
groupMapper MapFunc[*ent.Group, Group]
invitationMapper MapFunc[*ent.GroupInvitationToken, GroupInvitation]
}
func NewGroupRepository(db *ent.Client) *GroupRepository {
gmap := func(g *ent.Group) Group {
return Group{
ID: g.ID,
Name: g.Name,
CreatedAt: g.CreatedAt,
UpdatedAt: g.UpdatedAt,
Currency: strings.ToUpper(g.Currency),
}
}
imap := func(i *ent.GroupInvitationToken) GroupInvitation {
return GroupInvitation{
ID: i.ID,
ExpiresAt: i.ExpiresAt,
Uses: i.Uses,
Group: gmap(i.Edges.Group),
}
}
return &GroupRepository{
db: db,
groupMapper: gmap,
invitationMapper: imap,
}
}
type (
Group struct {
ID uuid.UUID `json:"id,omitempty"`
Name string `json:"name,omitempty"`
CreatedAt time.Time `json:"createdAt,omitempty"`
UpdatedAt time.Time `json:"updatedAt,omitempty"`
Currency string `json:"currency,omitempty"`
}
GroupUpdate struct {
Name string `json:"name"`
Currency string `json:"currency"`
}
GroupInvitationCreate struct {
Token []byte `json:"-"`
ExpiresAt time.Time `json:"expiresAt"`
Uses int `json:"uses"`
}
GroupInvitation struct {
ID uuid.UUID `json:"id"`
ExpiresAt time.Time `json:"expiresAt"`
Uses int `json:"uses"`
Group Group `json:"group"`
}
GroupStatistics struct {
TotalUsers int `json:"totalUsers"`
TotalItems int `json:"totalItems"`
TotalLocations int `json:"totalLocations"`
TotalLabels int `json:"totalLabels"`
TotalItemPrice float64 `json:"totalItemPrice"`
TotalWithWarranty int `json:"totalWithWarranty"`
}
ValueOverTimeEntry struct {
Date time.Time `json:"date"`
Value float64 `json:"value"`
Name string `json:"name"`
}
ValueOverTime struct {
PriceAtStart float64 `json:"valueAtStart"`
PriceAtEnd float64 `json:"valueAtEnd"`
Start time.Time `json:"start"`
End time.Time `json:"end"`
Entries []ValueOverTimeEntry `json:"entries"`
}
TotalsByOrganizer struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Total float64 `json:"total"`
}
)
func (r *GroupRepository) GetAllGroups(ctx context.Context) ([]Group, error) {
return r.groupMapper.MapEachErr(r.db.Group.Query().All(ctx))
}
func (r *GroupRepository) StatsLocationsByPurchasePrice(ctx context.Context, groupID uuid.UUID) ([]TotalsByOrganizer, error) {
var v []TotalsByOrganizer
err := r.db.Location.Query().
Where(
location.HasGroupWith(group.ID(groupID)),
).
GroupBy(location.FieldID, location.FieldName).
Aggregate(func(sq *sql.Selector) string {
t := sql.Table(item.Table)
sq.Join(t).On(sq.C(location.FieldID), t.C(item.LocationColumn))
return sql.As(sql.Sum(t.C(item.FieldPurchasePrice)), "total")
}).
Scan(ctx, &v)
if err != nil {
return nil, err
}
return v, err
}
func (r *GroupRepository) StatsLabelsByPurchasePrice(ctx context.Context, groupID uuid.UUID) ([]TotalsByOrganizer, error) {
var v []TotalsByOrganizer
err := r.db.Label.Query().
Where(
label.HasGroupWith(group.ID(groupID)),
).
GroupBy(label.FieldID, label.FieldName).
Aggregate(func(sq *sql.Selector) string {
itemTable := sql.Table(item.Table)
jt := sql.Table(label.ItemsTable)
sq.Join(jt).On(sq.C(label.FieldID), jt.C(label.ItemsPrimaryKey[0]))
sq.Join(itemTable).On(jt.C(label.ItemsPrimaryKey[1]), itemTable.C(item.FieldID))
return sql.As(sql.Sum(itemTable.C(item.FieldPurchasePrice)), "total")
}).
Scan(ctx, &v)
if err != nil {
return nil, err
}
return v, err
}
func (r *GroupRepository) StatsPurchasePrice(ctx context.Context, groupID uuid.UUID, start, end time.Time) (*ValueOverTime, error) {
// Get the Totals for the Start and End of the Given Time Period
q := `
SELECT
(SELECT Sum(purchase_price)
FROM items
WHERE group_items = ?
AND items.archived = false
AND items.created_at < ?) AS price_at_start,
(SELECT Sum(purchase_price)
FROM items
WHERE group_items = ?
AND items.archived = false
AND items.created_at < ?) AS price_at_end
`
stats := ValueOverTime{
Start: start,
End: end,
}
var maybeStart *float64
var maybeEnd *float64
row := r.db.Sql().QueryRowContext(ctx, q, groupID, sqliteDateFormat(start), groupID, sqliteDateFormat(end))
err := row.Scan(&maybeStart, &maybeEnd)
if err != nil {
return nil, err
}
stats.PriceAtStart = orDefault(maybeStart, 0)
stats.PriceAtEnd = orDefault(maybeEnd, 0)
var v []struct {
Name string `json:"name"`
CreatedAt time.Time `json:"created_at"`
PurchasePrice float64 `json:"purchase_price"`
}
// Get Created Date and Price of all items between start and end
err = r.db.Item.Query().
Where(
item.HasGroupWith(group.ID(groupID)),
item.CreatedAtGTE(start),
item.CreatedAtLTE(end),
item.Archived(false),
).
Select(
item.FieldName,
item.FieldCreatedAt,
item.FieldPurchasePrice,
).
Scan(ctx, &v)
if err != nil {
return nil, err
}
stats.Entries = make([]ValueOverTimeEntry, len(v))
for i, vv := range v {
stats.Entries[i] = ValueOverTimeEntry{
Date: vv.CreatedAt,
Value: vv.PurchasePrice,
}
}
return &stats, nil
}
func (r *GroupRepository) StatsGroup(ctx context.Context, groupID uuid.UUID) (GroupStatistics, error) {
q := `
SELECT
(SELECT COUNT(*) FROM users WHERE group_users = ?) AS total_users,
(SELECT COUNT(*) FROM items WHERE group_items = ? AND items.archived = false) AS total_items,
(SELECT COUNT(*) FROM locations WHERE group_locations = ?) AS total_locations,
(SELECT COUNT(*) FROM labels WHERE group_labels = ?) AS total_labels,
(SELECT SUM(purchase_price*quantity) FROM items WHERE group_items = ? AND items.archived = false) AS total_item_price,
(SELECT COUNT(*)
FROM items
WHERE group_items = ?
AND items.archived = false
AND (items.lifetime_warranty = true OR items.warranty_expires > date())
) AS total_with_warranty
`
var stats GroupStatistics
row := r.db.Sql().QueryRowContext(ctx, q, groupID, groupID, groupID, groupID, groupID, groupID)
var maybeTotalItemPrice *float64
var maybeTotalWithWarranty *int
err := row.Scan(&stats.TotalUsers, &stats.TotalItems, &stats.TotalLocations, &stats.TotalLabels, &maybeTotalItemPrice, &maybeTotalWithWarranty)
if err != nil {
return GroupStatistics{}, err
}
stats.TotalItemPrice = orDefault(maybeTotalItemPrice, 0)
stats.TotalWithWarranty = orDefault(maybeTotalWithWarranty, 0)
return stats, nil
}
func (r *GroupRepository) GroupCreate(ctx context.Context, name string) (Group, error) {
return r.groupMapper.MapErr(r.db.Group.Create().
SetName(name).
Save(ctx))
}
func (r *GroupRepository) GroupUpdate(ctx context.Context, groupID uuid.UUID, data GroupUpdate) (Group, error) {
entity, err := r.db.Group.UpdateOneID(groupID).
SetName(data.Name).
SetCurrency(strings.ToLower(data.Currency)).
Save(ctx)
return r.groupMapper.MapErr(entity, err)
}
func (r *GroupRepository) GroupByID(ctx context.Context, groupID uuid.UUID) (Group, error) {
return r.groupMapper.MapErr(r.db.Group.Get(ctx, groupID))
}
func (r *GroupRepository) InvitationGet(ctx context.Context, token []byte) (GroupInvitation, error) {
return r.invitationMapper.MapErr(r.db.GroupInvitationToken.Query().
Where(groupinvitationtoken.Token(token)).
WithGroup().
Only(ctx))
}
func (r *GroupRepository) InvitationCreate(ctx context.Context, groupID uuid.UUID, invite GroupInvitationCreate) (GroupInvitation, error) {
entity, err := r.db.GroupInvitationToken.Create().
SetGroupID(groupID).
SetToken(invite.Token).
SetExpiresAt(invite.ExpiresAt).
SetUses(invite.Uses).
Save(ctx)
if err != nil {
return GroupInvitation{}, err
}
return r.InvitationGet(ctx, entity.Token)
}
func (r *GroupRepository) InvitationUpdate(ctx context.Context, id uuid.UUID, uses int) error {
_, err := r.db.GroupInvitationToken.UpdateOneID(id).SetUses(uses).Save(ctx)
return err
}
// InvitationPurge removes all expired invitations or those that have been used up.
// It returns the number of deleted invitations.
func (r *GroupRepository) InvitationPurge(ctx context.Context) (amount int, err error) {
q := r.db.GroupInvitationToken.Delete()
q.Where(groupinvitationtoken.Or(
groupinvitationtoken.ExpiresAtLT(time.Now()),
groupinvitationtoken.UsesLTE(0),
))
return q.Exec(ctx)
}