implement custom http handler interface

This commit is contained in:
Hayden 2022-10-28 16:35:34 -08:00
parent 82269e8a95
commit 8a3d16860c
26 changed files with 752 additions and 480 deletions

View file

@ -76,9 +76,9 @@ func NewControllerV1(svc *services.AllServices, options ...func(*V1Controller))
// @Produce json // @Produce json
// @Success 200 {object} ApiSummary // @Success 200 {object} ApiSummary
// @Router /v1/status [GET] // @Router /v1/status [GET]
func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) http.HandlerFunc { func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
server.Respond(w, http.StatusOK, ApiSummary{ return server.Respond(w, http.StatusOK, ApiSummary{
Healthy: ready(), Healthy: ready(),
Title: "Go API Template", Title: "Go API Template",
Message: "Welcome to the Go API Template Application!", Message: "Welcome to the Go API Template Application!",

View file

@ -5,7 +5,7 @@ import (
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -13,8 +13,7 @@ func (ctrl *V1Controller) routeID(w http.ResponseWriter, r *http.Request) (uuid.
ID, err := uuid.Parse(chi.URLParam(r, "id")) ID, err := uuid.Parse(chi.URLParam(r, "id"))
if err != nil { if err != nil {
log.Err(err).Msg("failed to parse id") log.Err(err).Msg("failed to parse id")
server.RespondError(w, http.StatusBadRequest, err) return uuid.Nil, validate.ErrInvalidID
return uuid.Nil, err
} }
return ID, nil return ID, nil

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -32,17 +33,17 @@ type (
// @Produce json // @Produce json
// @Success 200 {object} TokenResponse // @Success 200 {object} TokenResponse
// @Router /v1/users/login [POST] // @Router /v1/users/login [POST]
func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc { func (ctrl *V1Controller) HandleAuthLogin() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
loginForm := &LoginForm{} loginForm := &LoginForm{}
switch r.Header.Get("Content-Type") { switch r.Header.Get("Content-Type") {
case server.ContentFormUrlEncoded: case server.ContentFormUrlEncoded:
err := r.ParseForm() err := r.ParseForm()
if err != nil { if err != nil {
server.Respond(w, http.StatusBadRequest, server.Wrap(err)) return server.Respond(w, http.StatusBadRequest, server.Wrap(err))
log.Error().Err(err).Msg("failed to parse form") log.Error().Err(err).Msg("failed to parse form")
return return nil
} }
loginForm.Username = r.PostFormValue("username") loginForm.Username = r.PostFormValue("username")
@ -52,27 +53,34 @@ func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc {
if err != nil { if err != nil {
log.Err(err).Msg("failed to decode login form") log.Err(err).Msg("failed to decode login form")
server.Respond(w, http.StatusBadRequest, server.Wrap(err)) return server.Respond(w, http.StatusBadRequest, server.Wrap(err))
return return nil
} }
default: default:
server.Respond(w, http.StatusBadRequest, errors.New("invalid content type")) return server.Respond(w, http.StatusBadRequest, errors.New("invalid content type"))
return return nil
} }
if loginForm.Username == "" || loginForm.Password == "" { if loginForm.Username == "" || loginForm.Password == "" {
server.RespondError(w, http.StatusBadRequest, errors.New("username and password are required")) return validate.NewFieldErrors(
return validate.FieldError{
Field: "username",
Error: "username or password is empty",
},
validate.FieldError{
Field: "password",
Error: "username or password is empty",
},
)
} }
newToken, err := ctrl.svc.User.Login(r.Context(), loginForm.Username, loginForm.Password) newToken, err := ctrl.svc.User.Login(r.Context(), loginForm.Username, loginForm.Password)
if err != nil { if err != nil {
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(errors.New("authentication failed"), http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, TokenResponse{ return server.Respond(w, http.StatusOK, TokenResponse{
Token: "Bearer " + newToken.Raw, Token: "Bearer " + newToken.Raw,
ExpiresAt: newToken.ExpiresAt, ExpiresAt: newToken.ExpiresAt,
}) })
@ -85,23 +93,19 @@ func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/users/logout [POST] // @Router /v1/users/logout [POST]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleAuthLogout() http.HandlerFunc { func (ctrl *V1Controller) HandleAuthLogout() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
token := services.UseTokenCtx(r.Context()) token := services.UseTokenCtx(r.Context())
if token == "" { if token == "" {
server.RespondError(w, http.StatusUnauthorized, errors.New("no token within request context")) return validate.NewRequestError(errors.New("no token within request context"), http.StatusUnauthorized)
return
} }
err := ctrl.svc.User.Logout(r.Context(), token) err := ctrl.svc.User.Logout(r.Context(), token)
if err != nil { if err != nil {
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
} }
@ -113,22 +117,18 @@ func (ctrl *V1Controller) HandleAuthLogout() http.HandlerFunc {
// @Success 200 // @Success 200
// @Router /v1/users/refresh [GET] // @Router /v1/users/refresh [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleAuthRefresh() http.HandlerFunc { func (ctrl *V1Controller) HandleAuthRefresh() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
requestToken := services.UseTokenCtx(r.Context()) requestToken := services.UseTokenCtx(r.Context())
if requestToken == "" { if requestToken == "" {
server.RespondError(w, http.StatusUnauthorized, errors.New("no user token found")) return validate.NewRequestError(errors.New("no token within request context"), http.StatusUnauthorized)
return
} }
newToken, err := ctrl.svc.User.RenewToken(r.Context(), requestToken) newToken, err := ctrl.svc.User.RenewToken(r.Context(), requestToken)
if err != nil { if err != nil {
server.RespondUnauthorized(w) return validate.UnauthorizedError()
return
} }
server.Respond(w, http.StatusOK, newToken) return server.Respond(w, http.StatusOK, newToken)
} }
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -31,7 +32,7 @@ type (
// @Success 200 {object} repo.Group // @Success 200 {object} repo.Group
// @Router /v1/groups [Get] // @Router /v1/groups [Get]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleGroupGet() http.HandlerFunc { func (ctrl *V1Controller) HandleGroupGet() server.HandlerFunc {
return ctrl.handleGroupGeneral() return ctrl.handleGroupGeneral()
} }
@ -43,12 +44,12 @@ func (ctrl *V1Controller) HandleGroupGet() http.HandlerFunc {
// @Success 200 {object} repo.Group // @Success 200 {object} repo.Group
// @Router /v1/groups [Put] // @Router /v1/groups [Put]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleGroupUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleGroupUpdate() server.HandlerFunc {
return ctrl.handleGroupGeneral() return ctrl.handleGroupGeneral()
} }
func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc { func (ctrl *V1Controller) handleGroupGeneral() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
switch r.Method { switch r.Method {
@ -56,29 +57,28 @@ func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc {
group, err := ctrl.svc.Group.Get(ctx) group, err := ctrl.svc.Group.Get(ctx)
if err != nil { if err != nil {
log.Err(err).Msg("failed to get group") log.Err(err).Msg("failed to get group")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower caseÍ group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower caseÍ
server.Respond(w, http.StatusOK, group) return server.Respond(w, http.StatusOK, group)
case http.MethodPut: case http.MethodPut:
data := repo.GroupUpdate{} data := repo.GroupUpdate{}
if err := server.Decode(r, &data); err != nil { if err := server.Decode(r, &data); err != nil {
server.RespondError(w, http.StatusBadRequest, err) return validate.NewRequestError(err, http.StatusBadRequest)
return
} }
group, err := ctrl.svc.Group.UpdateGroup(ctx, data) group, err := ctrl.svc.Group.UpdateGroup(ctx, data)
if err != nil { if err != nil {
log.Err(err).Msg("failed to update group") log.Err(err).Msg("failed to update group")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower case group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower case
server.Respond(w, http.StatusOK, group) return server.Respond(w, http.StatusOK, group)
} }
return nil
} }
} }
@ -90,13 +90,12 @@ func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc {
// @Success 200 {object} GroupInvitation // @Success 200 {object} GroupInvitation
// @Router /v1/groups/invitations [Post] // @Router /v1/groups/invitations [Post]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleGroupInvitationsCreate() http.HandlerFunc { func (ctrl *V1Controller) HandleGroupInvitationsCreate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
data := GroupInvitationCreate{} data := GroupInvitationCreate{}
if err := server.Decode(r, &data); err != nil { if err := server.Decode(r, &data); err != nil {
log.Err(err).Msg("failed to decode user registration data") log.Err(err).Msg("failed to decode user registration data")
server.RespondError(w, http.StatusBadRequest, err) return validate.NewRequestError(err, http.StatusBadRequest)
return
} }
if data.ExpiresAt.IsZero() { if data.ExpiresAt.IsZero() {
@ -108,11 +107,10 @@ func (ctrl *V1Controller) HandleGroupInvitationsCreate() http.HandlerFunc {
token, err := ctrl.svc.Group.NewInvitation(ctx, data.Uses, data.ExpiresAt) token, err := ctrl.svc.Group.NewInvitation(ctx, data.Uses, data.ExpiresAt)
if err != nil { if err != nil {
log.Err(err).Msg("failed to create new token") log.Err(err).Msg("failed to create new token")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusCreated, GroupInvitation{ return server.Respond(w, http.StatusCreated, GroupInvitation{
Token: token, Token: token,
ExpiresAt: data.ExpiresAt, ExpiresAt: data.ExpiresAt,
Uses: data.Uses, Uses: data.Uses,

View file

@ -9,6 +9,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -25,7 +26,7 @@ import (
// @Success 200 {object} repo.PaginationResult[repo.ItemSummary]{} // @Success 200 {object} repo.PaginationResult[repo.ItemSummary]{}
// @Router /v1/items [GET] // @Router /v1/items [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc { func (ctrl *V1Controller) HandleItemsGetAll() server.HandlerFunc {
uuidList := func(params url.Values, key string) []uuid.UUID { uuidList := func(params url.Values, key string) []uuid.UUID {
var ids []uuid.UUID var ids []uuid.UUID
for _, id := range params[key] { for _, id := range params[key] {
@ -58,15 +59,14 @@ func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc {
} }
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
items, err := ctrl.svc.Items.Query(ctx, extractQuery(r)) items, err := ctrl.svc.Items.Query(ctx, extractQuery(r))
if err != nil { if err != nil {
log.Err(err).Msg("failed to get items") log.Err(err).Msg("failed to get items")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, items) return server.Respond(w, http.StatusOK, items)
} }
} }
@ -78,24 +78,22 @@ func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc {
// @Success 200 {object} repo.ItemSummary // @Success 200 {object} repo.ItemSummary
// @Router /v1/items [POST] // @Router /v1/items [POST]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemsCreate() http.HandlerFunc { func (ctrl *V1Controller) HandleItemsCreate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
createData := repo.ItemCreate{} createData := repo.ItemCreate{}
if err := server.Decode(r, &createData); err != nil { if err := server.Decode(r, &createData); err != nil {
log.Err(err).Msg("failed to decode request body") log.Err(err).Msg("failed to decode request body")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
item, err := ctrl.svc.Items.Create(r.Context(), user.GroupID, createData) item, err := ctrl.svc.Items.Create(r.Context(), user.GroupID, createData)
if err != nil { if err != nil {
log.Err(err).Msg("failed to create item") log.Err(err).Msg("failed to create item")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusCreated, item) return server.Respond(w, http.StatusCreated, item)
} }
} }
@ -107,7 +105,7 @@ func (ctrl *V1Controller) HandleItemsCreate() http.HandlerFunc {
// @Success 200 {object} repo.ItemOut // @Success 200 {object} repo.ItemOut
// @Router /v1/items/{id} [GET] // @Router /v1/items/{id} [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemGet() http.HandlerFunc { func (ctrl *V1Controller) HandleItemGet() server.HandlerFunc {
return ctrl.handleItemsGeneral() return ctrl.handleItemsGeneral()
} }
@ -119,7 +117,7 @@ func (ctrl *V1Controller) HandleItemGet() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/items/{id} [DELETE] // @Router /v1/items/{id} [DELETE]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemDelete() http.HandlerFunc { func (ctrl *V1Controller) HandleItemDelete() server.HandlerFunc {
return ctrl.handleItemsGeneral() return ctrl.handleItemsGeneral()
} }
@ -132,16 +130,15 @@ func (ctrl *V1Controller) HandleItemDelete() http.HandlerFunc {
// @Success 200 {object} repo.ItemOut // @Success 200 {object} repo.ItemOut
// @Router /v1/items/{id} [PUT] // @Router /v1/items/{id} [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleItemUpdate() server.HandlerFunc {
return ctrl.handleItemsGeneral() return ctrl.handleItemsGeneral()
} }
func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc { func (ctrl *V1Controller) handleItemsGeneral() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
ID, err := ctrl.routeID(w, r) ID, err := ctrl.routeID(w, r)
if err != nil { if err != nil {
return
} }
switch r.Method { switch r.Method {
@ -149,37 +146,32 @@ func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc {
items, err := ctrl.svc.Items.GetOne(r.Context(), ctx.GID, ID) items, err := ctrl.svc.Items.GetOne(r.Context(), ctx.GID, ID)
if err != nil { if err != nil {
log.Err(err).Msg("failed to get item") log.Err(err).Msg("failed to get item")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, items) return server.Respond(w, http.StatusOK, items)
return
case http.MethodDelete: case http.MethodDelete:
err = ctrl.svc.Items.Delete(r.Context(), ctx.GID, ID) err = ctrl.svc.Items.Delete(r.Context(), ctx.GID, ID)
if err != nil { if err != nil {
log.Err(err).Msg("failed to delete item") log.Err(err).Msg("failed to delete item")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
return
case http.MethodPut: case http.MethodPut:
body := repo.ItemUpdate{} body := repo.ItemUpdate{}
if err := server.Decode(r, &body); err != nil { if err := server.Decode(r, &body); err != nil {
log.Err(err).Msg("failed to decode request body") log.Err(err).Msg("failed to decode request body")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
body.ID = ID body.ID = ID
result, err := ctrl.svc.Items.Update(r.Context(), ctx.GID, body) result, err := ctrl.svc.Items.Update(r.Context(), ctx.GID, body)
if err != nil { if err != nil {
log.Err(err).Msg("failed to update item") log.Err(err).Msg("failed to update item")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, result) return server.Respond(w, http.StatusOK, result)
} }
return nil
} }
} }
@ -191,29 +183,26 @@ func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc {
// @Param csv formData file true "Image to upload" // @Param csv formData file true "Image to upload"
// @Router /v1/items/import [Post] // @Router /v1/items/import [Post]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemsImport() http.HandlerFunc { func (ctrl *V1Controller) HandleItemsImport() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
err := r.ParseMultipartForm(ctrl.maxUploadSize << 20) err := r.ParseMultipartForm(ctrl.maxUploadSize << 20)
if err != nil { if err != nil {
log.Err(err).Msg("failed to parse multipart form") log.Err(err).Msg("failed to parse multipart form")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
file, _, err := r.FormFile("csv") file, _, err := r.FormFile("csv")
if err != nil { if err != nil {
log.Err(err).Msg("failed to get file from form") log.Err(err).Msg("failed to get file from form")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
reader := csv.NewReader(file) reader := csv.NewReader(file)
data, err := reader.ReadAll() data, err := reader.ReadAll()
if err != nil { if err != nil {
log.Err(err).Msg("failed to read csv") log.Err(err).Msg("failed to read csv")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
@ -221,10 +210,9 @@ func (ctrl *V1Controller) HandleItemsImport() http.HandlerFunc {
_, err = ctrl.svc.Items.CsvImport(r.Context(), user.GroupID, data) _, err = ctrl.svc.Items.CsvImport(r.Context(), user.GroupID, data)
if err != nil { if err != nil {
log.Err(err).Msg("failed to import items") log.Err(err).Msg("failed to import items")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
} }

View file

@ -10,6 +10,7 @@ import (
"github.com/hay-kot/homebox/backend/ent/attachment" "github.com/hay-kot/homebox/backend/ent/attachment"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -32,13 +33,13 @@ type (
// @Failure 422 {object} []server.ValidationError // @Failure 422 {object} []server.ValidationError
// @Router /v1/items/{id}/attachments [POST] // @Router /v1/items/{id}/attachments [POST]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { func (ctrl *V1Controller) HandleItemAttachmentCreate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
err := r.ParseMultipartForm(ctrl.maxUploadSize << 20) err := r.ParseMultipartForm(ctrl.maxUploadSize << 20)
if err != nil { if err != nil {
log.Err(err).Msg("failed to parse multipart form") log.Err(err).Msg("failed to parse multipart form")
server.RespondError(w, http.StatusBadRequest, errors.New("failed to parse multipart form")) return validate.NewRequestError(errors.New("failed to parse multipart form"), http.StatusBadRequest)
return
} }
errs := make(server.ValidationErrors, 0) errs := make(server.ValidationErrors, 0)
@ -51,8 +52,7 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc {
errs = errs.Append("file", "file is required") errs = errs.Append("file", "file is required")
default: default:
log.Err(err).Msg("failed to get file from form") log.Err(err).Msg("failed to get file from form")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
} }
@ -63,8 +63,7 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc {
} }
if errs.HasErrors() { if errs.HasErrors() {
server.Respond(w, http.StatusUnprocessableEntity, errs) return server.Respond(w, http.StatusUnprocessableEntity, errs)
return
} }
attachmentType := r.FormValue("type") attachmentType := r.FormValue("type")
@ -74,7 +73,6 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc {
id, err := ctrl.routeID(w, r) id, err := ctrl.routeID(w, r)
if err != nil { if err != nil {
return
} }
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
@ -89,11 +87,10 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc {
if err != nil { if err != nil {
log.Err(err).Msg("failed to add attachment") log.Err(err).Msg("failed to add attachment")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusCreated, item) return server.Respond(w, http.StatusCreated, item)
} }
} }
@ -106,21 +103,21 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc {
// @Success 200 // @Success 200
// @Router /v1/items/{id}/attachments/download [GET] // @Router /v1/items/{id}/attachments/download [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemAttachmentDownload() http.HandlerFunc { func (ctrl *V1Controller) HandleItemAttachmentDownload() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
token := server.GetParam(r, "token", "") token := server.GetParam(r, "token", "")
doc, err := ctrl.svc.Items.AttachmentPath(r.Context(), token) doc, err := ctrl.svc.Items.AttachmentPath(r.Context(), token)
if err != nil { if err != nil {
log.Err(err).Msg("failed to get attachment") log.Err(err).Msg("failed to get attachment")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", doc.Title)) w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", doc.Title))
w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Type", "application/octet-stream")
http.ServeFile(w, r, doc.Path) http.ServeFile(w, r, doc.Path)
return nil
} }
} }
@ -133,7 +130,7 @@ func (ctrl *V1Controller) HandleItemAttachmentDownload() http.HandlerFunc {
// @Success 200 {object} ItemAttachmentToken // @Success 200 {object} ItemAttachmentToken
// @Router /v1/items/{id}/attachments/{attachment_id} [GET] // @Router /v1/items/{id}/attachments/{attachment_id} [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemAttachmentToken() http.HandlerFunc { func (ctrl *V1Controller) HandleItemAttachmentToken() server.HandlerFunc {
return ctrl.handleItemAttachmentsHandler return ctrl.handleItemAttachmentsHandler
} }
@ -145,7 +142,7 @@ func (ctrl *V1Controller) HandleItemAttachmentToken() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/items/{id}/attachments/{attachment_id} [DELETE] // @Router /v1/items/{id}/attachments/{attachment_id} [DELETE]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemAttachmentDelete() http.HandlerFunc { func (ctrl *V1Controller) HandleItemAttachmentDelete() server.HandlerFunc {
return ctrl.handleItemAttachmentsHandler return ctrl.handleItemAttachmentsHandler
} }
@ -158,21 +155,19 @@ func (ctrl *V1Controller) HandleItemAttachmentDelete() http.HandlerFunc {
// @Success 200 {object} repo.ItemOut // @Success 200 {object} repo.ItemOut
// @Router /v1/items/{id}/attachments/{attachment_id} [PUT] // @Router /v1/items/{id}/attachments/{attachment_id} [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleItemAttachmentUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleItemAttachmentUpdate() server.HandlerFunc {
return ctrl.handleItemAttachmentsHandler return ctrl.handleItemAttachmentsHandler
} }
func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r *http.Request) { func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r *http.Request) error {
ID, err := ctrl.routeID(w, r) ID, err := ctrl.routeID(w, r)
if err != nil { if err != nil {
return
} }
attachmentId, err := uuid.Parse(chi.URLParam(r, "attachment_id")) attachmentId, err := uuid.Parse(chi.URLParam(r, "attachment_id"))
if err != nil { if err != nil {
log.Err(err).Msg("failed to parse attachment_id param") log.Err(err).Msg("failed to parse attachment_id param")
server.RespondError(w, http.StatusBadRequest, err) return validate.NewRequestError(err, http.StatusBadRequest)
return
} }
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
@ -189,7 +184,7 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r
Str("id", attachmentId.String()). Str("id", attachmentId.String()).
Msg("failed to find attachment with id") Msg("failed to find attachment with id")
server.RespondError(w, http.StatusNotFound, err) return validate.NewRequestError(err, http.StatusNotFound)
case services.ErrFileNotFound: case services.ErrFileNotFound:
log.Err(err). log.Err(err).
@ -197,27 +192,25 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r
Msg("failed to find file path for attachment with id") Msg("failed to find file path for attachment with id")
log.Warn().Msg("attachment with no file path removed from database") log.Warn().Msg("attachment with no file path removed from database")
server.RespondError(w, http.StatusNotFound, err) return validate.NewRequestError(err, http.StatusNotFound)
default: default:
log.Err(err).Msg("failed to get attachment") log.Err(err).Msg("failed to get attachment")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
} }
server.Respond(w, http.StatusOK, ItemAttachmentToken{Token: token}) return server.Respond(w, http.StatusOK, ItemAttachmentToken{Token: token})
// Delete Attachment Handler // Delete Attachment Handler
case http.MethodDelete: case http.MethodDelete:
err = ctrl.svc.Items.AttachmentDelete(r.Context(), ctx.GID, ID, attachmentId) err = ctrl.svc.Items.AttachmentDelete(r.Context(), ctx.GID, ID, attachmentId)
if err != nil { if err != nil {
log.Err(err).Msg("failed to delete attachment") log.Err(err).Msg("failed to delete attachment")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
// Update Attachment Handler // Update Attachment Handler
case http.MethodPut: case http.MethodPut:
@ -225,18 +218,18 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r
err = server.Decode(r, &attachment) err = server.Decode(r, &attachment)
if err != nil { if err != nil {
log.Err(err).Msg("failed to decode attachment") log.Err(err).Msg("failed to decode attachment")
server.RespondError(w, http.StatusBadRequest, err) return validate.NewRequestError(err, http.StatusBadRequest)
return
} }
attachment.ID = attachmentId attachment.ID = attachmentId
val, err := ctrl.svc.Items.AttachmentUpdate(ctx, ID, &attachment) val, err := ctrl.svc.Items.AttachmentUpdate(ctx, ID, &attachment)
if err != nil { if err != nil {
log.Err(err).Msg("failed to delete attachment") log.Err(err).Msg("failed to delete attachment")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, val) return server.Respond(w, http.StatusOK, val)
} }
return nil
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/hay-kot/homebox/backend/ent" "github.com/hay-kot/homebox/backend/ent"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -17,16 +18,15 @@ import (
// @Success 200 {object} server.Results{items=[]repo.LabelOut} // @Success 200 {object} server.Results{items=[]repo.LabelOut}
// @Router /v1/labels [GET] // @Router /v1/labels [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLabelsGetAll() http.HandlerFunc { func (ctrl *V1Controller) HandleLabelsGetAll() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
labels, err := ctrl.svc.Labels.GetAll(r.Context(), user.GroupID) labels, err := ctrl.svc.Labels.GetAll(r.Context(), user.GroupID)
if err != nil { if err != nil {
log.Err(err).Msg("error getting labels") log.Err(err).Msg("error getting labels")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, server.Results{Items: labels}) return server.Respond(w, http.StatusOK, server.Results{Items: labels})
} }
} }
@ -38,24 +38,22 @@ func (ctrl *V1Controller) HandleLabelsGetAll() http.HandlerFunc {
// @Success 200 {object} repo.LabelSummary // @Success 200 {object} repo.LabelSummary
// @Router /v1/labels [POST] // @Router /v1/labels [POST]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLabelsCreate() http.HandlerFunc { func (ctrl *V1Controller) HandleLabelsCreate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
createData := repo.LabelCreate{} createData := repo.LabelCreate{}
if err := server.Decode(r, &createData); err != nil { if err := server.Decode(r, &createData); err != nil {
log.Err(err).Msg("error decoding label create data") log.Err(err).Msg("error decoding label create data")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
label, err := ctrl.svc.Labels.Create(r.Context(), user.GroupID, createData) label, err := ctrl.svc.Labels.Create(r.Context(), user.GroupID, createData)
if err != nil { if err != nil {
log.Err(err).Msg("error creating label") log.Err(err).Msg("error creating label")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusCreated, label) return server.Respond(w, http.StatusCreated, label)
} }
} }
@ -67,7 +65,7 @@ func (ctrl *V1Controller) HandleLabelsCreate() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/labels/{id} [DELETE] // @Router /v1/labels/{id} [DELETE]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLabelDelete() http.HandlerFunc { func (ctrl *V1Controller) HandleLabelDelete() server.HandlerFunc {
return ctrl.handleLabelsGeneral() return ctrl.handleLabelsGeneral()
} }
@ -79,7 +77,7 @@ func (ctrl *V1Controller) HandleLabelDelete() http.HandlerFunc {
// @Success 200 {object} repo.LabelOut // @Success 200 {object} repo.LabelOut
// @Router /v1/labels/{id} [GET] // @Router /v1/labels/{id} [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLabelGet() http.HandlerFunc { func (ctrl *V1Controller) HandleLabelGet() server.HandlerFunc {
return ctrl.handleLabelsGeneral() return ctrl.handleLabelsGeneral()
} }
@ -91,16 +89,15 @@ func (ctrl *V1Controller) HandleLabelGet() http.HandlerFunc {
// @Success 200 {object} repo.LabelOut // @Success 200 {object} repo.LabelOut
// @Router /v1/labels/{id} [PUT] // @Router /v1/labels/{id} [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLabelUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleLabelUpdate() server.HandlerFunc {
return ctrl.handleLabelsGeneral() return ctrl.handleLabelsGeneral()
} }
func (ctrl *V1Controller) handleLabelsGeneral() http.HandlerFunc { func (ctrl *V1Controller) handleLabelsGeneral() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
ID, err := ctrl.routeID(w, r) ID, err := ctrl.routeID(w, r)
if err != nil { if err != nil {
return
} }
switch r.Method { switch r.Method {
@ -111,40 +108,37 @@ func (ctrl *V1Controller) handleLabelsGeneral() http.HandlerFunc {
log.Err(err). log.Err(err).
Str("id", ID.String()). Str("id", ID.String()).
Msg("label not found") Msg("label not found")
server.RespondError(w, http.StatusNotFound, err) return validate.NewRequestError(err, http.StatusNotFound)
return
} }
log.Err(err).Msg("error getting label") log.Err(err).Msg("error getting label")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, labels) return server.Respond(w, http.StatusOK, labels)
case http.MethodDelete: case http.MethodDelete:
err = ctrl.svc.Labels.Delete(r.Context(), ctx.GID, ID) err = ctrl.svc.Labels.Delete(r.Context(), ctx.GID, ID)
if err != nil { if err != nil {
log.Err(err).Msg("error deleting label") log.Err(err).Msg("error deleting label")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
case http.MethodPut: case http.MethodPut:
body := repo.LabelUpdate{} body := repo.LabelUpdate{}
if err := server.Decode(r, &body); err != nil { if err := server.Decode(r, &body); err != nil {
log.Err(err).Msg("error decoding label update data") log.Err(err).Msg("error decoding label update data")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
body.ID = ID body.ID = ID
result, err := ctrl.svc.Labels.Update(r.Context(), ctx.GID, body) result, err := ctrl.svc.Labels.Update(r.Context(), ctx.GID, body)
if err != nil { if err != nil {
log.Err(err).Msg("error updating label") log.Err(err).Msg("error updating label")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, result) return server.Respond(w, http.StatusOK, result)
} }
return nil
} }
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/hay-kot/homebox/backend/ent" "github.com/hay-kot/homebox/backend/ent"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -17,17 +18,16 @@ import (
// @Success 200 {object} server.Results{items=[]repo.LocationOutCount} // @Success 200 {object} server.Results{items=[]repo.LocationOutCount}
// @Router /v1/locations [GET] // @Router /v1/locations [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLocationGetAll() http.HandlerFunc { func (ctrl *V1Controller) HandleLocationGetAll() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
locations, err := ctrl.svc.Location.GetAll(r.Context(), user.GroupID) locations, err := ctrl.svc.Location.GetAll(r.Context(), user.GroupID)
if err != nil { if err != nil {
log.Err(err).Msg("failed to get locations") log.Err(err).Msg("failed to get locations")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, server.Results{Items: locations}) return server.Respond(w, http.StatusOK, server.Results{Items: locations})
} }
} }
@ -39,24 +39,22 @@ func (ctrl *V1Controller) HandleLocationGetAll() http.HandlerFunc {
// @Success 200 {object} repo.LocationSummary // @Success 200 {object} repo.LocationSummary
// @Router /v1/locations [POST] // @Router /v1/locations [POST]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLocationCreate() http.HandlerFunc { func (ctrl *V1Controller) HandleLocationCreate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
createData := repo.LocationCreate{} createData := repo.LocationCreate{}
if err := server.Decode(r, &createData); err != nil { if err := server.Decode(r, &createData); err != nil {
log.Err(err).Msg("failed to decode location create data") log.Err(err).Msg("failed to decode location create data")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
user := services.UseUserCtx(r.Context()) user := services.UseUserCtx(r.Context())
location, err := ctrl.svc.Location.Create(r.Context(), user.GroupID, createData) location, err := ctrl.svc.Location.Create(r.Context(), user.GroupID, createData)
if err != nil { if err != nil {
log.Err(err).Msg("failed to create location") log.Err(err).Msg("failed to create location")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusCreated, location) return server.Respond(w, http.StatusCreated, location)
} }
} }
@ -68,7 +66,7 @@ func (ctrl *V1Controller) HandleLocationCreate() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/locations/{id} [DELETE] // @Router /v1/locations/{id} [DELETE]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLocationDelete() http.HandlerFunc { func (ctrl *V1Controller) HandleLocationDelete() server.HandlerFunc {
return ctrl.handleLocationGeneral() return ctrl.handleLocationGeneral()
} }
@ -80,7 +78,7 @@ func (ctrl *V1Controller) HandleLocationDelete() http.HandlerFunc {
// @Success 200 {object} repo.LocationOut // @Success 200 {object} repo.LocationOut
// @Router /v1/locations/{id} [GET] // @Router /v1/locations/{id} [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLocationGet() http.HandlerFunc { func (ctrl *V1Controller) HandleLocationGet() server.HandlerFunc {
return ctrl.handleLocationGeneral() return ctrl.handleLocationGeneral()
} }
@ -93,16 +91,15 @@ func (ctrl *V1Controller) HandleLocationGet() http.HandlerFunc {
// @Success 200 {object} repo.LocationOut // @Success 200 {object} repo.LocationOut
// @Router /v1/locations/{id} [PUT] // @Router /v1/locations/{id} [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleLocationUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleLocationUpdate() server.HandlerFunc {
return ctrl.handleLocationGeneral() return ctrl.handleLocationGeneral()
} }
func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc { func (ctrl *V1Controller) handleLocationGeneral() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
ctx := services.NewContext(r.Context()) ctx := services.NewContext(r.Context())
ID, err := ctrl.routeID(w, r) ID, err := ctrl.routeID(w, r)
if err != nil { if err != nil {
return
} }
switch r.Method { switch r.Method {
@ -115,21 +112,18 @@ func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc {
if ent.IsNotFound(err) { if ent.IsNotFound(err) {
l.Msg("location not found") l.Msg("location not found")
server.RespondError(w, http.StatusNotFound, err) return validate.NewRequestError(err, http.StatusNotFound)
return
} }
l.Msg("failed to get location") l.Msg("failed to get location")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, location) return server.Respond(w, http.StatusOK, location)
case http.MethodPut: case http.MethodPut:
body := repo.LocationUpdate{} body := repo.LocationUpdate{}
if err := server.Decode(r, &body); err != nil { if err := server.Decode(r, &body); err != nil {
log.Err(err).Msg("failed to decode location update data") log.Err(err).Msg("failed to decode location update data")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
body.ID = ID body.ID = ID
@ -137,18 +131,17 @@ func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc {
result, err := ctrl.svc.Location.Update(r.Context(), ctx.GID, body) result, err := ctrl.svc.Location.Update(r.Context(), ctx.GID, body)
if err != nil { if err != nil {
log.Err(err).Msg("failed to update location") log.Err(err).Msg("failed to update location")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, result) return server.Respond(w, http.StatusOK, result)
case http.MethodDelete: case http.MethodDelete:
err = ctrl.svc.Location.Delete(r.Context(), ctx.GID, ID) err = ctrl.svc.Location.Delete(r.Context(), ctx.GID, ID)
if err != nil { if err != nil {
log.Err(err).Msg("failed to delete location") log.Err(err).Msg("failed to delete location")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
return nil
} }
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -17,29 +18,26 @@ import (
// @Param payload body services.UserRegistration true "User Data" // @Param payload body services.UserRegistration true "User Data"
// @Success 204 // @Success 204
// @Router /v1/users/register [Post] // @Router /v1/users/register [Post]
func (ctrl *V1Controller) HandleUserRegistration() http.HandlerFunc { func (ctrl *V1Controller) HandleUserRegistration() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
regData := services.UserRegistration{} regData := services.UserRegistration{}
if err := server.Decode(r, &regData); err != nil { if err := server.Decode(r, &regData); err != nil {
log.Err(err).Msg("failed to decode user registration data") log.Err(err).Msg("failed to decode user registration data")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
if !ctrl.allowRegistration && regData.GroupToken == "" { if !ctrl.allowRegistration && regData.GroupToken == "" {
server.RespondError(w, http.StatusForbidden, nil) return validate.NewRequestError(nil, http.StatusForbidden)
return
} }
_, err := ctrl.svc.User.RegisterUser(r.Context(), regData) _, err := ctrl.svc.User.RegisterUser(r.Context(), regData)
if err != nil { if err != nil {
log.Err(err).Msg("failed to register user") log.Err(err).Msg("failed to register user")
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
} }
@ -50,17 +48,16 @@ func (ctrl *V1Controller) HandleUserRegistration() http.HandlerFunc {
// @Success 200 {object} server.Result{item=repo.UserOut} // @Success 200 {object} server.Result{item=repo.UserOut}
// @Router /v1/users/self [GET] // @Router /v1/users/self [GET]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleUserSelf() http.HandlerFunc { func (ctrl *V1Controller) HandleUserSelf() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
token := services.UseTokenCtx(r.Context()) token := services.UseTokenCtx(r.Context())
usr, err := ctrl.svc.User.GetSelf(r.Context(), token) usr, err := ctrl.svc.User.GetSelf(r.Context(), token)
if usr.ID == uuid.Nil || err != nil { if usr.ID == uuid.Nil || err != nil {
log.Err(err).Msg("failed to get user") log.Err(err).Msg("failed to get user")
server.RespondServerError(w) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, server.Wrap(usr)) return server.Respond(w, http.StatusOK, server.Wrap(usr))
} }
} }
@ -72,24 +69,22 @@ func (ctrl *V1Controller) HandleUserSelf() http.HandlerFunc {
// @Success 200 {object} server.Result{item=repo.UserUpdate} // @Success 200 {object} server.Result{item=repo.UserUpdate}
// @Router /v1/users/self [PUT] // @Router /v1/users/self [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleUserSelfUpdate() http.HandlerFunc { func (ctrl *V1Controller) HandleUserSelfUpdate() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
updateData := repo.UserUpdate{} updateData := repo.UserUpdate{}
if err := server.Decode(r, &updateData); err != nil { if err := server.Decode(r, &updateData); err != nil {
log.Err(err).Msg("failed to decode user update data") log.Err(err).Msg("failed to decode user update data")
server.RespondError(w, http.StatusBadRequest, err) return validate.NewRequestError(err, http.StatusBadRequest)
return
} }
actor := services.UseUserCtx(r.Context()) actor := services.UseUserCtx(r.Context())
newData, err := ctrl.svc.User.UpdateSelf(r.Context(), actor.ID, updateData) newData, err := ctrl.svc.User.UpdateSelf(r.Context(), actor.ID, updateData)
if err != nil { if err != nil {
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusOK, server.Wrap(newData)) return server.Respond(w, http.StatusOK, server.Wrap(newData))
} }
} }
@ -100,20 +95,18 @@ func (ctrl *V1Controller) HandleUserSelfUpdate() http.HandlerFunc {
// @Success 204 // @Success 204
// @Router /v1/users/self [DELETE] // @Router /v1/users/self [DELETE]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleUserSelfDelete() http.HandlerFunc { func (ctrl *V1Controller) HandleUserSelfDelete() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
if ctrl.isDemo { if ctrl.isDemo {
server.RespondError(w, http.StatusForbidden, nil) return validate.NewRequestError(nil, http.StatusForbidden)
return
} }
actor := services.UseUserCtx(r.Context()) actor := services.UseUserCtx(r.Context())
if err := ctrl.svc.User.DeleteSelf(r.Context(), actor.ID); err != nil { if err := ctrl.svc.User.DeleteSelf(r.Context(), actor.ID); err != nil {
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
} }
@ -131,11 +124,10 @@ type (
// @Param payload body ChangePassword true "Password Payload" // @Param payload body ChangePassword true "Password Payload"
// @Router /v1/users/change-password [PUT] // @Router /v1/users/change-password [PUT]
// @Security Bearer // @Security Bearer
func (ctrl *V1Controller) HandleUserSelfChangePassword() http.HandlerFunc { func (ctrl *V1Controller) HandleUserSelfChangePassword() server.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
if ctrl.isDemo { if ctrl.isDemo {
server.RespondError(w, http.StatusForbidden, nil) return validate.NewRequestError(nil, http.StatusForbidden)
return
} }
var cp ChangePassword var cp ChangePassword
@ -148,10 +140,9 @@ func (ctrl *V1Controller) HandleUserSelfChangePassword() http.HandlerFunc {
ok := ctrl.svc.User.ChangePassword(ctx, cp.Current, cp.New) ok := ctrl.svc.User.ChangePassword(ctx, cp.Current, cp.New)
if !ok { if !ok {
server.RespondError(w, http.StatusInternalServerError, err) return validate.NewRequestError(err, http.StatusInternalServerError)
return
} }
server.Respond(w, http.StatusNoContent, nil) return server.Respond(w, http.StatusNoContent, nil)
} }
} }

View file

@ -15,7 +15,7 @@ func (a *app) setupLogger() {
// Logger Init // Logger Init
// zerolog.TimeFieldFormat = zerolog.TimeFormatUnix // zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
if a.conf.Log.Format != config.LogFormatJSON { if a.conf.Log.Format != config.LogFormatJSON {
log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).With().Caller().Logger()
} }
log.Level(getLevel(a.conf.Log.Level)) log.Level(getLevel(a.conf.Log.Level))

View file

@ -15,6 +15,7 @@ import (
"github.com/hay-kot/homebox/backend/internal/migrations" "github.com/hay-kot/homebox/backend/internal/migrations"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/web/mid"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@ -114,17 +115,25 @@ func run(cfg *config.Config) error {
app.services = services.New(app.repos) app.services = services.New(app.repos)
// ========================================================================= // =========================================================================
// Start Server // Start Server\
logger := log.With().Caller().Logger()
mwLogger := mid.Logger(logger)
if app.conf.Mode == config.ModeDevelopment {
mwLogger = mid.SugarLogger(logger)
}
app.server = server.NewServer( app.server = server.NewServer(
server.WithHost(app.conf.Web.Host), server.WithHost(app.conf.Web.Host),
server.WithPort(app.conf.Web.Port), server.WithPort(app.conf.Web.Port),
server.WithMiddleware(
mwLogger,
mid.Errors(logger),
mid.Panic(app.conf.Mode == config.ModeDevelopment),
),
) )
routes := app.newRouter(app.repos) app.mountRoutes(app.repos)
if app.conf.Mode != config.ModeDevelopment {
app.logRoutes(routes)
}
log.Info().Msgf("Starting HTTP Server on %s:%s", app.server.Host, app.server.Port) log.Info().Msgf("Starting HTTP Server on %s:%s", app.server.Host, app.server.Port)
@ -163,5 +172,5 @@ func run(cfg *config.Config) error {
}() }()
} }
return app.server.Start(routes) return app.server.Start()
} }

View file

@ -1,143 +1,34 @@
package main package main
import ( import (
"fmt" "errors"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/hay-kot/homebox/backend/internal/config"
"github.com/hay-kot/homebox/backend/internal/services" "github.com/hay-kot/homebox/backend/internal/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server" "github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog/log"
) )
func (a *app) setGlobalMiddleware(r *chi.Mux) {
// =========================================================================
// Middleware
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
r.Use(mwStripTrailingSlash)
// Use struct logger in production for requests, but use
// pretty console logger in development.
if a.conf.Mode == config.ModeDevelopment {
r.Use(a.mwSummaryLogger)
} else {
r.Use(a.mwStructLogger)
}
r.Use(middleware.Recoverer)
// Set a timeout value on the request context (ctx), that will signal
// through ctx.Done() that the request has timed out and further
// processing should be stopped.
r.Use(middleware.Timeout(60 * time.Second))
}
// mwAuthToken is a middleware that will check the database for a stateful token // mwAuthToken is a middleware that will check the database for a stateful token
// and attach it to the request context with the user, or return a 401 if it doesn't exist. // and attach it to the request context with the user, or return a 401 if it doesn't exist.
func (a *app) mwAuthToken(next http.Handler) http.Handler { func (a *app) mwAuthToken(next server.Handler) server.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
requestToken := r.Header.Get("Authorization") requestToken := r.Header.Get("Authorization")
if requestToken == "" { if requestToken == "" {
server.RespondUnauthorized(w) return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized)
return
} }
requestToken = strings.TrimPrefix(requestToken, "Bearer ") requestToken = strings.TrimPrefix(requestToken, "Bearer ")
usr, err := a.services.User.GetSelf(r.Context(), requestToken) usr, err := a.services.User.GetSelf(r.Context(), requestToken)
// Check the database for the token // Check the database for the token
if err != nil { if err != nil {
server.RespondUnauthorized(w) return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized)
return
} }
r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken)) r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken))
return next.ServeHTTP(w, r)
next.ServeHTTP(w, r)
})
}
// mqStripTrailingSlash is a middleware that will strip trailing slashes from the request path.
func mwStripTrailingSlash(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r.URL.Path = strings.TrimSuffix(r.URL.Path, "/")
next.ServeHTTP(w, r)
})
}
type StatusRecorder struct {
http.ResponseWriter
Status int
}
func (r *StatusRecorder) WriteHeader(status int) {
r.Status = status
r.ResponseWriter.WriteHeader(status)
}
func (a *app) mwStructLogger(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK}
next.ServeHTTP(record, r)
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto)
log.Info().
Str("id", middleware.GetReqID(r.Context())).
Str("method", r.Method).
Str("remote_addr", r.RemoteAddr).
Int("status", record.Status).
Msg(url)
})
}
func (a *app) mwSummaryLogger(next http.Handler) http.Handler {
bold := func(s string) string { return "\033[1m" + s + "\033[0m" }
orange := func(s string) string { return "\033[33m" + s + "\033[0m" }
aqua := func(s string) string { return "\033[36m" + s + "\033[0m" }
red := func(s string) string { return "\033[31m" + s + "\033[0m" }
green := func(s string) string { return "\033[32m" + s + "\033[0m" }
fmtCode := func(code int) string {
switch {
case code >= 500:
return red(fmt.Sprintf("%d", code))
case code >= 400:
return orange(fmt.Sprintf("%d", code))
case code >= 300:
return aqua(fmt.Sprintf("%d", code))
default:
return green(fmt.Sprintf("%d", code))
}
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK}
next.ServeHTTP(record, r) // Blocks until the next handler returns.
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto)
log.Info().
Msgf("%s %s %s",
bold(orange(""+r.Method+"")),
aqua(url),
bold(fmtCode(record.Status)),
)
}) })
} }

View file

@ -15,6 +15,7 @@ import (
v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1" v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1"
_ "github.com/hay-kot/homebox/backend/app/api/static/docs" _ "github.com/hay-kot/homebox/backend/app/api/static/docs"
"github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/repo"
"github.com/hay-kot/homebox/backend/pkgs/server"
httpSwagger "github.com/swaggo/http-swagger" // http-swagger middleware httpSwagger "github.com/swaggo/http-swagger" // http-swagger middleware
) )
@ -35,80 +36,76 @@ func (a *app) debugRouter() *http.ServeMux {
} }
// registerRoutes registers all the routes for the API // registerRoutes registers all the routes for the API
func (a *app) newRouter(repos *repo.AllRepos) *chi.Mux { func (a *app) mountRoutes(repos *repo.AllRepos) {
registerMimes() registerMimes()
r := chi.NewRouter() a.server.Get("/swagger/*", server.ToHandler(httpSwagger.Handler(
a.setGlobalMiddleware(r)
r.Get("/swagger/*", httpSwagger.Handler(
httpSwagger.URL(fmt.Sprintf("%s://%s/swagger/doc.json", a.conf.Swagger.Scheme, a.conf.Swagger.Host)), httpSwagger.URL(fmt.Sprintf("%s://%s/swagger/doc.json", a.conf.Swagger.Scheme, a.conf.Swagger.Host)),
)) )))
// ========================================================================= // =========================================================================
// API Version 1 // API Version 1
v1Base := v1.BaseUrlFunc(prefix) v1Base := v1.BaseUrlFunc(prefix)
v1Ctrl := v1.NewControllerV1(a.services,
v1Ctrl := v1.NewControllerV1(
a.services,
v1.WithMaxUploadSize(a.conf.Web.MaxUploadSize), v1.WithMaxUploadSize(a.conf.Web.MaxUploadSize),
v1.WithRegistration(a.conf.AllowRegistration), v1.WithRegistration(a.conf.AllowRegistration),
v1.WithDemoStatus(a.conf.Demo), // Disable Password Change in Demo Mode v1.WithDemoStatus(a.conf.Demo), // Disable Password Change in Demo Mode
) )
r.Get(v1Base("/status"), v1Ctrl.HandleBase(func() bool { return true }, v1.Build{
a.server.Get(v1Base("/status"), v1Ctrl.HandleBase(func() bool { return true }, v1.Build{
Version: version, Version: version,
Commit: commit, Commit: commit,
BuildTime: buildTime, BuildTime: buildTime,
})) }))
r.Post(v1Base("/users/register"), v1Ctrl.HandleUserRegistration()) a.server.Post(v1Base("/users/register"), v1Ctrl.HandleUserRegistration())
r.Post(v1Base("/users/login"), v1Ctrl.HandleAuthLogin()) a.server.Post(v1Base("/users/login"), v1Ctrl.HandleAuthLogin())
// Attachment download URl needs a `token` query param to be passed in the request. // Attachment download URl needs a `token` query param to be passed in the request.
// and also needs to be outside of the `auth` middleware. // and also needs to be outside of the `auth` middleware.
r.Get(v1Base("/items/{id}/attachments/download"), v1Ctrl.HandleItemAttachmentDownload()) a.server.Get(v1Base("/items/{id}/attachments/download"), v1Ctrl.HandleItemAttachmentDownload())
r.Group(func(r chi.Router) { a.server.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf(), a.mwAuthToken)
r.Use(a.mwAuthToken) a.server.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate(), a.mwAuthToken)
r.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf()) a.server.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete(), a.mwAuthToken)
r.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate()) a.server.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout(), a.mwAuthToken)
r.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete()) a.server.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh(), a.mwAuthToken)
r.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout()) a.server.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword(), a.mwAuthToken)
r.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh())
r.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword())
r.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate()) a.server.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate(), a.mwAuthToken)
// TODO: I don't like /groups being the URL for users // TODO: I don't like /groups being the URL for users
r.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet()) a.server.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet(), a.mwAuthToken)
r.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate()) a.server.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate(), a.mwAuthToken)
r.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll()) a.server.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll(), a.mwAuthToken)
r.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate()) a.server.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate(), a.mwAuthToken)
r.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet()) a.server.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet(), a.mwAuthToken)
r.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate()) a.server.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate(), a.mwAuthToken)
r.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete()) a.server.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete(), a.mwAuthToken)
r.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll()) a.server.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll(), a.mwAuthToken)
r.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate()) a.server.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate(), a.mwAuthToken)
r.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet()) a.server.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet(), a.mwAuthToken)
r.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate()) a.server.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate(), a.mwAuthToken)
r.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete()) a.server.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete(), a.mwAuthToken)
r.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll()) a.server.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll(), a.mwAuthToken)
r.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport()) a.server.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport(), a.mwAuthToken)
r.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate()) a.server.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate(), a.mwAuthToken)
r.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet()) a.server.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet(), a.mwAuthToken)
r.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate()) a.server.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate(), a.mwAuthToken)
r.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete()) a.server.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete(), a.mwAuthToken)
r.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate()) a.server.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate(), a.mwAuthToken)
r.Get(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentToken()) a.server.Get(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentToken(), a.mwAuthToken)
r.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate()) a.server.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate(), a.mwAuthToken)
r.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete()) a.server.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete(), a.mwAuthToken)
})
r.NotFound(notFoundHandler()) a.server.NotFound(notFoundHandler())
return r
} }
// logRoutes logs the routes of the server that are registered within Server.registerRoutes(). This is useful for debugging. // logRoutes logs the routes of the server that are registered within Server.registerRoutes(). This is useful for debugging.
@ -146,7 +143,7 @@ func registerMimes() {
// notFoundHandler perform the main logic around handling the internal SPA embed and ensuring that // notFoundHandler perform the main logic around handling the internal SPA embed and ensuring that
// the client side routing is handled correctly. // the client side routing is handled correctly.
func notFoundHandler() http.HandlerFunc { func notFoundHandler() server.HandlerFunc {
tryRead := func(fs embed.FS, prefix, requestedPath string, w http.ResponseWriter) error { tryRead := func(fs embed.FS, prefix, requestedPath string, w http.ResponseWriter) error {
f, err := fs.Open(path.Join(prefix, requestedPath)) f, err := fs.Open(path.Join(prefix, requestedPath))
if err != nil { if err != nil {
@ -165,14 +162,16 @@ func notFoundHandler() http.HandlerFunc {
return err return err
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) error {
err := tryRead(public, "static/public", r.URL.Path, w) err := tryRead(public, "static/public", r.URL.Path, w)
if err == nil {
return
}
err = tryRead(public, "static/public", "index.html", w)
if err != nil { if err != nil {
panic(err) // Fallback to the index.html file.
// should succeed in all cases.
err = tryRead(public, "static/public", "index.html", w)
if err != nil {
return err
}
} }
return nil
} }
} }

View file

@ -0,0 +1,82 @@
package validate
import (
"encoding/json"
"errors"
)
func UnauthorizedError() error {
return errors.New("unauthorized")
}
// ErrInvalidID occurs when an ID is not in a valid form.
var ErrInvalidID = errors.New("ID is not in its proper form")
// ErrorResponse is the form used for API responses from failures in the API.
type ErrorResponse struct {
Error string `json:"error"`
Fields string `json:"fields,omitempty"`
}
// RequestError is used to pass an error during the request through the
// application with web specific context.
type RequestError struct {
Err error
Status int
Fields error
}
// NewRequestError wraps a provided error with an HTTP status code. This
// function should be used when handlers encounter expected errors.
func NewRequestError(err error, status int) error {
return &RequestError{err, status, nil}
}
func (err *RequestError) Error() string {
return err.Err.Error()
}
// IsRequestError checks if an error of type RequestError exists.
func IsRequestError(err error) bool {
var re *RequestError
return errors.As(err, &re)
}
// FieldError is used to indicate an error with a specific request field.
type FieldError struct {
Field string `json:"field"`
Error string `json:"error"`
}
// FieldErrors represents a collection of field errors.
type FieldErrors []FieldError
// Error implments the error interface.
func (fe FieldErrors) Error() string {
d, err := json.Marshal(fe)
if err != nil {
return err.Error()
}
return string(d)
}
func NewFieldErrors(errs ...FieldError) FieldErrors {
return errs
}
func IsFieldError(err error) bool {
v := FieldErrors{}
return errors.As(err, &v)
}
// Cause iterates through all the wrapped errors until the root
// error value is reached.
func Cause(err error) error {
root := err
for {
if err = errors.Unwrap(root); err == nil {
return root
}
root = err
}
}

View file

@ -0,0 +1,66 @@
package mid
import (
"errors"
"net/http"
"github.com/hay-kot/homebox/backend/ent"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog"
)
func Errors(log zerolog.Logger) server.Middleware {
return func(h server.Handler) server.Handler {
return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
err := h.ServeHTTP(w, r)
if err != nil {
var resp server.ErrorResponse
var code int
log.Err(err).
Str("trace_id", "TODO").
Msg("ERROR occurred")
switch {
case errors.Is(err, validate.ErrInvalidID):
code = http.StatusBadRequest
resp = server.ErrorResponse{
Error: "invalid id parameter",
}
case validate.IsFieldError(err):
fieldErrors := err.(validate.FieldErrors)
resp.Error = "Validation Error"
resp.Fields = map[string]string{}
for _, fieldError := range fieldErrors {
resp.Fields[fieldError.Field] = fieldError.Error
}
case validate.IsRequestError(err):
requestError := err.(*validate.RequestError)
resp.Error = requestError.Error()
code = requestError.Status
case ent.IsNotFound(err):
resp.Error = "Not Found"
code = http.StatusNotFound
default:
resp.Error = "Unknown Error"
code = http.StatusInternalServerError
}
if err := server.Respond(w, code, resp); err != nil {
return err
}
// If Showdown error, return error
if server.IsShutdownError(err) {
return err
}
}
return nil
})
}
}

View file

@ -0,0 +1,94 @@
package mid
import (
"fmt"
"net/http"
"github.com/hay-kot/homebox/backend/pkgs/server"
"github.com/rs/zerolog"
)
type StatusRecorder struct {
http.ResponseWriter
Status int
}
func (r *StatusRecorder) WriteHeader(status int) {
r.Status = status
r.ResponseWriter.WriteHeader(status)
}
func Logger(log zerolog.Logger) server.Middleware {
return func(next server.Handler) server.Handler {
return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
log.Info().
Str("trace_id", "TODO").
Str("method", r.Method).
Str("path", r.URL.Path).
Str("remove_address", r.RemoteAddr).
Msg("request started")
record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK}
err := next.ServeHTTP(record, r)
log.Info().
Str("trave_id", "TODO").
Str("method", r.Method).
Str("url", r.URL.Path).
Str("remote_address", r.RemoteAddr).
Int("status_code", record.Status).
Msg("request completed")
return err
})
}
}
func SugarLogger(log zerolog.Logger) server.Middleware {
orange := func(s string) string { return "\033[33m" + s + "\033[0m" }
aqua := func(s string) string { return "\033[36m" + s + "\033[0m" }
red := func(s string) string { return "\033[31m" + s + "\033[0m" }
green := func(s string) string { return "\033[32m" + s + "\033[0m" }
fmtCode := func(code int) string {
switch {
case code >= 500:
return red(fmt.Sprintf("%d", code))
case code >= 400:
return orange(fmt.Sprintf("%d", code))
case code >= 300:
return aqua(fmt.Sprintf("%d", code))
default:
return green(fmt.Sprintf("%d", code))
}
}
bold := func(s string) string { return "\033[1m" + s + "\033[0m" }
return func(next server.Handler) server.Handler {
return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK}
err := next.ServeHTTP(record, r) // Blocks until the next handler returns.
scheme := "http"
if r.TLS != nil {
scheme = "https"
}
url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto)
log.Info().
Msgf("%s %s %s",
bold(fmtCode(record.Status)),
bold(orange(""+r.Method+"")),
aqua(url),
)
return err
})
}
}

View file

@ -0,0 +1,33 @@
package mid
import (
"fmt"
"net/http"
"runtime/debug"
"github.com/hay-kot/homebox/backend/pkgs/server"
)
// Panic is a middleware that recovers from panics anywhere in the chain and wraps the error.
// and returns it up the middleware chain.
func Panic(develop bool) server.Middleware {
return func(h server.Handler) server.Handler {
return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) {
defer func() {
if rec := recover(); rec != nil {
trace := debug.Stack()
if develop {
err = fmt.Errorf("PANIC [%v]", rec)
fmt.Printf("%s", string(trace))
} else {
err = fmt.Errorf("PANIC [%v] TRACE[%s]", rec, string(trace))
}
}
}()
return h.ServeHTTP(w, r)
})
}
}

View file

@ -0,0 +1,23 @@
package server
import "errors"
type shutdownError struct {
message string
}
func (e *shutdownError) Error() string {
return e.message
}
// ShutdownError returns an error that indicates that the server has lost
// integrity and should be shut down.
func ShutdownError(message string) error {
return &shutdownError{message}
}
// IsShutdownError returns true if the error is a shutdown error.
func IsShutdownError(err error) bool {
var e *shutdownError
return errors.As(err, &e)
}

View file

@ -0,0 +1,25 @@
package server
import (
"net/http"
)
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
return f(w, r)
}
type Handler interface {
ServeHTTP(http.ResponseWriter, *http.Request) error
}
// ToHandler converts a function to a customer implementation of the Handler interface.
// that returns an error. This wrapper around the handler function and simply
// returns the nil in all cases
func ToHandler(handler http.Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
handler.ServeHTTP(w, r)
return nil
})
}

View file

@ -0,0 +1,36 @@
package server
import (
"net/http"
"strings"
)
type Middleware func(Handler) Handler
// wrapMiddleware creates a new handler by wrapping middleware around a final
// handler. The middlewares' Handlers will be executed by requests in the order
// they are provided.
func wrapMiddleware(mw []Middleware, handler Handler) Handler {
// Loop backwards through the middleware invoking each one. Replace the
// handler with the new wrapped handler. Looping backwards ensures that the
// first middleware of the slice is the first to be executed by requests.
for i := len(mw) - 1; i >= 0; i-- {
h := mw[i]
if h != nil {
handler = h(handler)
}
}
return handler
}
// StripTrailingSlash is a middleware that will strip trailing slashes from the request path.
func StripTrailingSlash() Middleware {
return func(h Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
r.URL.Path = strings.TrimSuffix(r.URL.Path, "/")
return h.ServeHTTP(w, r)
})
}
}

View file

@ -0,0 +1,98 @@
package server
import (
"context"
"net/http"
)
type vkey int
const (
// Key is the key for the server in the request context.
key vkey = 1
)
type Values struct {
TraceID string
}
func (s *Server) toHttpHandler(handler Handler, mw ...Middleware) http.HandlerFunc {
handler = wrapMiddleware(mw, handler)
handler = wrapMiddleware(s.mw, handler)
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Add the trace ID to the context
ctx = context.WithValue(ctx, key, Values{
// TODO: Initialize a new trace ID
TraceID: "00000000-0000-0000-0000-000000000000",
})
err := handler.ServeHTTP(w, r.WithContext(ctx))
if err != nil {
if IsShutdownError(err) {
s.Shutdown("SIGTERM")
}
}
}
}
func (s *Server) handle(method, pattern string, handler Handler, mw ...Middleware) {
h := s.toHttpHandler(handler, mw...)
switch method {
case http.MethodGet:
s.mux.Get(pattern, h)
case http.MethodPost:
s.mux.Post(pattern, h)
case http.MethodPut:
s.mux.Put(pattern, h)
case http.MethodDelete:
s.mux.Delete(pattern, h)
case http.MethodPatch:
s.mux.Patch(pattern, h)
case http.MethodHead:
s.mux.Head(pattern, h)
case http.MethodOptions:
s.mux.Options(pattern, h)
}
}
func (s *Server) Handler(pattern string, handler Handler, mw ...Middleware) {
s.mux.Handle(pattern, s.toHttpHandler(handler, mw...))
}
func (s *Server) Get(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodGet, pattern, handler, mw...)
}
func (s *Server) Post(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPost, pattern, handler, mw...)
}
func (s *Server) Put(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPut, pattern, handler, mw...)
}
func (s *Server) Delete(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodDelete, pattern, handler, mw...)
}
func (s *Server) Patch(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPatch, pattern, handler, mw...)
}
func (s *Server) Head(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodHead, pattern, handler, mw...)
}
func (s *Server) Options(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodOptions, pattern, handler, mw...)
}
func (s *Server) NotFound(handler Handler) {
s.mux.NotFound(s.toHttpHandler(handler))
}

View file

@ -2,16 +2,20 @@ package server
import ( import (
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
) )
type ErrorResponse struct {
Error string `json:"error"`
Fields map[string]string `json:"fields,omitempty"`
}
// Respond converts a Go value to JSON and sends it to the client. // Respond converts a Go value to JSON and sends it to the client.
// Adapted from https://github.com/ardanlabs/service/tree/master/foundation/web // Adapted from https://github.com/ardanlabs/service/tree/master/foundation/web
func Respond(w http.ResponseWriter, statusCode int, data interface{}) { func Respond(w http.ResponseWriter, statusCode int, data interface{}) error {
if statusCode == http.StatusNoContent { if statusCode == http.StatusNoContent {
w.WriteHeader(statusCode) w.WriteHeader(statusCode)
return return nil
} }
// Convert the response value to JSON. // Convert the response value to JSON.
@ -28,31 +32,8 @@ func Respond(w http.ResponseWriter, statusCode int, data interface{}) {
// Send the result back to the client. // Send the result back to the client.
if _, err := w.Write(jsonData); err != nil { if _, err := w.Write(jsonData); err != nil {
panic(err) return err
} }
}
// ResponseError is a helper function that sends a JSON response of an error message return nil
func RespondError(w http.ResponseWriter, statusCode int, err error) {
eb := ErrorBuilder{}
eb.AddError(err)
eb.Respond(w, statusCode)
}
// RespondServerError is a wrapper around RespondError that sends a 500 internal server error. Useful for
// Sending generic errors when everything went wrong.
func RespondServerError(w http.ResponseWriter) {
RespondError(w, http.StatusInternalServerError, errors.New("internal server error"))
}
// RespondNotFound is a helper utility for responding with a generic
// "unauthorized" error.
func RespondUnauthorized(w http.ResponseWriter) {
RespondError(w, http.StatusUnauthorized, errors.New("unauthorized"))
}
// RespondForbidden is a helper utility for responding with a generic
// "forbidden" error.
func RespondForbidden(w http.ResponseWriter) {
RespondError(w, http.StatusForbidden, errors.New("forbidden"))
} }

View file

@ -1,7 +1,6 @@
package server package server
import ( import (
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -38,41 +37,3 @@ func Test_Respond_JSON(t *testing.T) {
assert.Equal(t, "application/json", recorder.Header().Get("Content-Type")) assert.Equal(t, "application/json", recorder.Header().Get("Content-Type"))
} }
func Test_RespondError(t *testing.T) {
recorder := httptest.NewRecorder()
var customError = errors.New("custom error")
RespondError(recorder, http.StatusBadRequest, customError)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["custom error"], "message":"Bad Request", "error":true}`)
}
func Test_RespondInternalServerError(t *testing.T) {
recorder := httptest.NewRecorder()
RespondServerError(recorder)
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["internal server error"], "message":"Internal Server Error", "error":true}`)
}
func Test_RespondUnauthorized(t *testing.T) {
recorder := httptest.NewRecorder()
RespondUnauthorized(recorder)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["unauthorized"], "message":"Unauthorized", "error":true}`)
}
func Test_RespondForbidden(t *testing.T) {
recorder := httptest.NewRecorder()
RespondForbidden(recorder)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["forbidden"], "message":"Forbidden", "error":true}`)
}

View file

@ -10,6 +10,8 @@ import (
"sync" "sync"
"syscall" "syscall"
"time" "time"
"github.com/go-chi/chi/v5"
) )
var ( var (
@ -22,7 +24,11 @@ type Server struct {
Port string Port string
Worker Worker Worker Worker
wg sync.WaitGroup wg sync.WaitGroup
mux *chi.Mux
// mw is the global middleware chain for the server.
mw []Middleware
started bool started bool
activeServer *http.Server activeServer *http.Server
@ -36,6 +42,7 @@ func NewServer(opts ...Option) *Server {
s := &Server{ s := &Server{
Host: "localhost", Host: "localhost",
Port: "8080", Port: "8080",
mux: chi.NewRouter(),
Worker: NewSimpleWorker(), Worker: NewSimpleWorker(),
idleTimeout: 30 * time.Second, idleTimeout: 30 * time.Second,
readTimeout: 10 * time.Second, readTimeout: 10 * time.Second,
@ -75,14 +82,14 @@ func (s *Server) Shutdown(sig string) error {
} }
func (s *Server) Start(router http.Handler) error { func (s *Server) Start() error {
if s.started { if s.started {
return ErrServerAlreadyStarted return ErrServerAlreadyStarted
} }
s.activeServer = &http.Server{ s.activeServer = &http.Server{
Addr: s.Host + ":" + s.Port, Addr: s.Host + ":" + s.Port,
Handler: router, Handler: s.mux,
IdleTimeout: s.idleTimeout, IdleTimeout: s.idleTimeout,
ReadTimeout: s.readTimeout, ReadTimeout: s.readTimeout,
WriteTimeout: s.writeTimeout, WriteTimeout: s.writeTimeout,

View file

@ -4,6 +4,13 @@ import "time"
type Option = func(s *Server) error type Option = func(s *Server) error
func WithMiddleware(mw ...Middleware) Option {
return func(s *Server) error {
s.mw = append(s.mw, mw...)
return nil
}
}
func WithWorker(w Worker) Option { func WithWorker(w Worker) Option {
return func(s *Server) error { return func(s *Server) error {
s.Worker = w s.Worker = w

View file

@ -12,8 +12,12 @@ import (
func testServer(t *testing.T, r http.Handler) *Server { func testServer(t *testing.T, r http.Handler) *Server {
svr := NewServer(WithHost("127.0.0.1"), WithPort("19245")) svr := NewServer(WithHost("127.0.0.1"), WithPort("19245"))
if r != nil {
svr.mux.Handle("/", r)
}
go func() { go func() {
err := svr.Start(r) err := svr.Start()
assert.NoError(t, err) assert.NoError(t, err)
}() }()
@ -42,7 +46,7 @@ func Test_ServerShutdown_Error(t *testing.T) {
func Test_ServerStarts_Error(t *testing.T) { func Test_ServerStarts_Error(t *testing.T) {
svr := testServer(t, nil) svr := testServer(t, nil)
err := svr.Start(nil) err := svr.Start()
assert.ErrorIs(t, err, ErrServerAlreadyStarted) assert.ErrorIs(t, err, ErrServerAlreadyStarted)
err = svr.Shutdown("test") err = svr.Shutdown("test")