diff --git a/backend/internal/repo/map_helpers.go b/backend/internal/repo/map_helpers.go new file mode 100644 index 0000000..38a56b4 --- /dev/null +++ b/backend/internal/repo/map_helpers.go @@ -0,0 +1,44 @@ +package repo + +// errMapperFunc is a factory function that returns a mapper function that +// wraps the given mapper function but first will check for an error and +// return the error if present. +// +// Helpful for wrapping database calls that return both a value and an error +func errMapperFunc[T any, Y any](fn func(T) Y) func(T, error) (Y, error) { + return func(t T, err error) (Y, error) { + if err != nil { + var zero Y + return zero, err + } + + return fn(t), nil + } +} + +// TODO: Future Usage +// func mapEachFunc[T any, Y any](fn func(T) Y) func([]T) []Y { +// return func(items []T) []Y { +// result := make([]Y, len(items)) +// for i, item := range items { +// result[i] = fn(item) +// } + +// return result +// } +// } + +func mapEachFuncErr[T any, Y any](fn func(T) Y) func([]T, error) ([]Y, error) { + return func(items []T, err error) ([]Y, error) { + if err != nil { + return nil, err + } + + result := make([]Y, len(items)) + for i, item := range items { + result[i] = fn(item) + } + + return result, nil + } +} diff --git a/backend/internal/repo/repo_documents.go b/backend/internal/repo/repo_documents.go index 34cf088..09ec42c 100644 --- a/backend/internal/repo/repo_documents.go +++ b/backend/internal/repo/repo_documents.go @@ -28,26 +28,47 @@ type ( Title string Content io.Reader } + + DocumentOut struct { + ID uuid.UUID + Title string + Path string + } +) + +func mapDocumentOut(doc *ent.Document) DocumentOut { + return DocumentOut{ + ID: doc.ID, + Title: doc.Title, + Path: doc.Path, + } +} + +var ( + mapDocumentOutErr = errMapperFunc(mapDocumentOut) + mapDocumentOutEachErr = mapEachFuncErr(mapDocumentOut) ) func (r *DocumentRepository) path(gid uuid.UUID, ext string) string { return pathlib.Safe(filepath.Join(r.dir, gid.String(), "documents", uuid.NewString()+ext)) } -func (r *DocumentRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]*ent.Document, error) { - return r.db.Document.Query(). +func (r *DocumentRepository) GetAll(ctx context.Context, gid uuid.UUID) ([]DocumentOut, error) { + return mapDocumentOutEachErr(r.db.Document. + Query(). Where(document.HasGroupWith(group.ID(gid))). - All(ctx) + All(ctx), + ) } -func (r *DocumentRepository) Get(ctx context.Context, id uuid.UUID) (*ent.Document, error) { - return r.db.Document.Get(ctx, id) +func (r *DocumentRepository) Get(ctx context.Context, id uuid.UUID) (DocumentOut, error) { + return mapDocumentOutErr(r.db.Document.Get(ctx, id)) } -func (r *DocumentRepository) Create(ctx context.Context, gid uuid.UUID, doc DocumentCreate) (*ent.Document, error) { +func (r *DocumentRepository) Create(ctx context.Context, gid uuid.UUID, doc DocumentCreate) (DocumentOut, error) { ext := filepath.Ext(doc.Title) if ext == "" { - return nil, ErrInvalidDocExtension + return DocumentOut{}, ErrInvalidDocExtension } path := r.path(gid, ext) @@ -55,30 +76,31 @@ func (r *DocumentRepository) Create(ctx context.Context, gid uuid.UUID, doc Docu parent := filepath.Dir(path) err := os.MkdirAll(parent, 0755) if err != nil { - return nil, err + return DocumentOut{}, err } f, err := os.Create(path) if err != nil { - return nil, err + return DocumentOut{}, err } _, err = io.Copy(f, doc.Content) if err != nil { - return nil, err + return DocumentOut{}, err } - return r.db.Document.Create(). + return mapDocumentOutErr(r.db.Document.Create(). SetGroupID(gid). SetTitle(doc.Title). SetPath(path). - Save(ctx) + Save(ctx), + ) } -func (r *DocumentRepository) Rename(ctx context.Context, id uuid.UUID, title string) (*ent.Document, error) { - return r.db.Document.UpdateOneID(id). +func (r *DocumentRepository) Rename(ctx context.Context, id uuid.UUID, title string) (DocumentOut, error) { + return mapDocumentOutErr(r.db.Document.UpdateOneID(id). SetTitle(title). - Save(ctx) + Save(ctx)) } func (r *DocumentRepository) Delete(ctx context.Context, id uuid.UUID) error { diff --git a/backend/internal/repo/repo_documents_test.go b/backend/internal/repo/repo_documents_test.go index bb1467a..06e631c 100644 --- a/backend/internal/repo/repo_documents_test.go +++ b/backend/internal/repo/repo_documents_test.go @@ -13,10 +13,10 @@ import ( "github.com/stretchr/testify/assert" ) -func useDocs(t *testing.T, num int) []*ent.Document { +func useDocs(t *testing.T, num int) []DocumentOut { t.Helper() - results := make([]*ent.Document, 0, num) + results := make([]DocumentOut, 0, num) ids := make([]uuid.UUID, 0, num) for i := 0; i < num; i++ { diff --git a/backend/internal/repo/repo_documents_tokens_test.go b/backend/internal/repo/repo_documents_tokens_test.go index ebd84ed..f80be0e 100644 --- a/backend/internal/repo/repo_documents_tokens_test.go +++ b/backend/internal/repo/repo_documents_tokens_test.go @@ -39,7 +39,9 @@ func TestDocumentTokensRepository_Create(t *testing.T) { }, want: &ent.DocumentToken{ Edges: ent.DocumentTokenEdges{ - Document: doc, + Document: &ent.Document{ + ID: doc.ID, + }, }, Token: []byte("token"), ExpiresAt: expires,