From 8a3d16860c3341ba08cb913da995b1fe26cdb33e Mon Sep 17 00:00:00 2001 From: Hayden <64056131+hay-kot@users.noreply.github.com> Date: Fri, 28 Oct 2022 16:35:34 -0800 Subject: [PATCH] implement custom http handler interface --- backend/app/api/handlers/v1/controller.go | 6 +- backend/app/api/handlers/v1/partials.go | 5 +- backend/app/api/handlers/v1/v1_ctrl_auth.go | 62 ++++----- backend/app/api/handlers/v1/v1_ctrl_group.go | 36 +++-- backend/app/api/handlers/v1/v1_ctrl_items.go | 72 +++++----- .../handlers/v1/v1_ctrl_items_attachments.go | 65 +++++---- backend/app/api/handlers/v1/v1_ctrl_labels.go | 56 ++++---- .../app/api/handlers/v1/v1_ctrl_locations.go | 55 ++++---- backend/app/api/handlers/v1/v1_ctrl_user.go | 61 ++++----- backend/app/api/logger.go | 2 +- backend/app/api/main.go | 23 +++- backend/app/api/middleware.go | 123 +----------------- backend/app/api/routes.go | 105 ++++++++------- backend/internal/sys/validate/errors.go | 82 ++++++++++++ backend/internal/web/mid/errors.go | 66 ++++++++++ backend/internal/web/mid/logger.go | 94 +++++++++++++ backend/internal/web/mid/panic.go | 33 +++++ backend/pkgs/server/errors.go | 23 ++++ backend/pkgs/server/handler.go | 25 ++++ backend/pkgs/server/middleware.go | 36 +++++ backend/pkgs/server/mux.go | 98 ++++++++++++++ backend/pkgs/server/response.go | 37 ++---- backend/pkgs/server/response_test.go | 39 ------ backend/pkgs/server/server.go | 13 +- backend/pkgs/server/server_options.go | 7 + backend/pkgs/server/server_test.go | 8 +- 26 files changed, 752 insertions(+), 480 deletions(-) create mode 100644 backend/internal/sys/validate/errors.go create mode 100644 backend/internal/web/mid/errors.go create mode 100644 backend/internal/web/mid/logger.go create mode 100644 backend/internal/web/mid/panic.go create mode 100644 backend/pkgs/server/errors.go create mode 100644 backend/pkgs/server/handler.go create mode 100644 backend/pkgs/server/middleware.go create mode 100644 backend/pkgs/server/mux.go diff --git a/backend/app/api/handlers/v1/controller.go b/backend/app/api/handlers/v1/controller.go index ed1f3f7..2c41bc1 100644 --- a/backend/app/api/handlers/v1/controller.go +++ b/backend/app/api/handlers/v1/controller.go @@ -76,9 +76,9 @@ func NewControllerV1(svc *services.AllServices, options ...func(*V1Controller)) // @Produce json // @Success 200 {object} ApiSummary // @Router /v1/status [GET] -func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - server.Respond(w, http.StatusOK, ApiSummary{ +func (ctrl *V1Controller) HandleBase(ready ReadyFunc, build Build) server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { + return server.Respond(w, http.StatusOK, ApiSummary{ Healthy: ready(), Title: "Go API Template", Message: "Welcome to the Go API Template Application!", diff --git a/backend/app/api/handlers/v1/partials.go b/backend/app/api/handlers/v1/partials.go index 47249e6..e95c7a6 100644 --- a/backend/app/api/handlers/v1/partials.go +++ b/backend/app/api/handlers/v1/partials.go @@ -5,7 +5,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" - "github.com/hay-kot/homebox/backend/pkgs/server" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/rs/zerolog/log" ) @@ -13,8 +13,7 @@ func (ctrl *V1Controller) routeID(w http.ResponseWriter, r *http.Request) (uuid. ID, err := uuid.Parse(chi.URLParam(r, "id")) if err != nil { log.Err(err).Msg("failed to parse id") - server.RespondError(w, http.StatusBadRequest, err) - return uuid.Nil, err + return uuid.Nil, validate.ErrInvalidID } return ID, nil diff --git a/backend/app/api/handlers/v1/v1_ctrl_auth.go b/backend/app/api/handlers/v1/v1_ctrl_auth.go index d410a86..a3f49b4 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_auth.go +++ b/backend/app/api/handlers/v1/v1_ctrl_auth.go @@ -6,6 +6,7 @@ import ( "time" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -32,17 +33,17 @@ type ( // @Produce json // @Success 200 {object} TokenResponse // @Router /v1/users/login [POST] -func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleAuthLogin() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { loginForm := &LoginForm{} switch r.Header.Get("Content-Type") { case server.ContentFormUrlEncoded: err := r.ParseForm() if err != nil { - server.Respond(w, http.StatusBadRequest, server.Wrap(err)) + return server.Respond(w, http.StatusBadRequest, server.Wrap(err)) log.Error().Err(err).Msg("failed to parse form") - return + return nil } loginForm.Username = r.PostFormValue("username") @@ -52,27 +53,34 @@ func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc { if err != nil { log.Err(err).Msg("failed to decode login form") - server.Respond(w, http.StatusBadRequest, server.Wrap(err)) - return + return server.Respond(w, http.StatusBadRequest, server.Wrap(err)) + return nil } default: - server.Respond(w, http.StatusBadRequest, errors.New("invalid content type")) - return + return server.Respond(w, http.StatusBadRequest, errors.New("invalid content type")) + return nil } if loginForm.Username == "" || loginForm.Password == "" { - server.RespondError(w, http.StatusBadRequest, errors.New("username and password are required")) - return + return validate.NewFieldErrors( + validate.FieldError{ + Field: "username", + Error: "username or password is empty", + }, + validate.FieldError{ + Field: "password", + Error: "username or password is empty", + }, + ) } newToken, err := ctrl.svc.User.Login(r.Context(), loginForm.Username, loginForm.Password) if err != nil { - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(errors.New("authentication failed"), http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, TokenResponse{ + return server.Respond(w, http.StatusOK, TokenResponse{ Token: "Bearer " + newToken.Raw, ExpiresAt: newToken.ExpiresAt, }) @@ -85,23 +93,19 @@ func (ctrl *V1Controller) HandleAuthLogin() http.HandlerFunc { // @Success 204 // @Router /v1/users/logout [POST] // @Security Bearer -func (ctrl *V1Controller) HandleAuthLogout() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleAuthLogout() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { token := services.UseTokenCtx(r.Context()) - if token == "" { - server.RespondError(w, http.StatusUnauthorized, errors.New("no token within request context")) - return + return validate.NewRequestError(errors.New("no token within request context"), http.StatusUnauthorized) } err := ctrl.svc.User.Logout(r.Context(), token) - if err != nil { - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } } @@ -113,22 +117,18 @@ func (ctrl *V1Controller) HandleAuthLogout() http.HandlerFunc { // @Success 200 // @Router /v1/users/refresh [GET] // @Security Bearer -func (ctrl *V1Controller) HandleAuthRefresh() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleAuthRefresh() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { requestToken := services.UseTokenCtx(r.Context()) - if requestToken == "" { - server.RespondError(w, http.StatusUnauthorized, errors.New("no user token found")) - return + return validate.NewRequestError(errors.New("no token within request context"), http.StatusUnauthorized) } newToken, err := ctrl.svc.User.RenewToken(r.Context(), requestToken) - if err != nil { - server.RespondUnauthorized(w) - return + return validate.UnauthorizedError() } - server.Respond(w, http.StatusOK, newToken) + return server.Respond(w, http.StatusOK, newToken) } } diff --git a/backend/app/api/handlers/v1/v1_ctrl_group.go b/backend/app/api/handlers/v1/v1_ctrl_group.go index fe87379..a403877 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_group.go +++ b/backend/app/api/handlers/v1/v1_ctrl_group.go @@ -7,6 +7,7 @@ import ( "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -31,7 +32,7 @@ type ( // @Success 200 {object} repo.Group // @Router /v1/groups [Get] // @Security Bearer -func (ctrl *V1Controller) HandleGroupGet() http.HandlerFunc { +func (ctrl *V1Controller) HandleGroupGet() server.HandlerFunc { return ctrl.handleGroupGeneral() } @@ -43,12 +44,12 @@ func (ctrl *V1Controller) HandleGroupGet() http.HandlerFunc { // @Success 200 {object} repo.Group // @Router /v1/groups [Put] // @Security Bearer -func (ctrl *V1Controller) HandleGroupUpdate() http.HandlerFunc { +func (ctrl *V1Controller) HandleGroupUpdate() server.HandlerFunc { return ctrl.handleGroupGeneral() } -func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) handleGroupGeneral() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { ctx := services.NewContext(r.Context()) switch r.Method { @@ -56,29 +57,28 @@ func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc { group, err := ctrl.svc.Group.Get(ctx) if err != nil { log.Err(err).Msg("failed to get group") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower caseƍ - server.Respond(w, http.StatusOK, group) + return server.Respond(w, http.StatusOK, group) case http.MethodPut: data := repo.GroupUpdate{} if err := server.Decode(r, &data); err != nil { - server.RespondError(w, http.StatusBadRequest, err) - return + return validate.NewRequestError(err, http.StatusBadRequest) } group, err := ctrl.svc.Group.UpdateGroup(ctx, data) if err != nil { log.Err(err).Msg("failed to update group") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } group.Currency = strings.ToUpper(group.Currency) // TODO: Hack to fix the currency enums being lower case - server.Respond(w, http.StatusOK, group) + return server.Respond(w, http.StatusOK, group) } + + return nil } } @@ -90,13 +90,12 @@ func (ctrl *V1Controller) handleGroupGeneral() http.HandlerFunc { // @Success 200 {object} GroupInvitation // @Router /v1/groups/invitations [Post] // @Security Bearer -func (ctrl *V1Controller) HandleGroupInvitationsCreate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleGroupInvitationsCreate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { data := GroupInvitationCreate{} if err := server.Decode(r, &data); err != nil { log.Err(err).Msg("failed to decode user registration data") - server.RespondError(w, http.StatusBadRequest, err) - return + return validate.NewRequestError(err, http.StatusBadRequest) } if data.ExpiresAt.IsZero() { @@ -108,11 +107,10 @@ func (ctrl *V1Controller) HandleGroupInvitationsCreate() http.HandlerFunc { token, err := ctrl.svc.Group.NewInvitation(ctx, data.Uses, data.ExpiresAt) if err != nil { log.Err(err).Msg("failed to create new token") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusCreated, GroupInvitation{ + return server.Respond(w, http.StatusCreated, GroupInvitation{ Token: token, ExpiresAt: data.ExpiresAt, Uses: data.Uses, diff --git a/backend/app/api/handlers/v1/v1_ctrl_items.go b/backend/app/api/handlers/v1/v1_ctrl_items.go index 49fe8b6..09cbb5b 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_items.go +++ b/backend/app/api/handlers/v1/v1_ctrl_items.go @@ -9,6 +9,7 @@ import ( "github.com/google/uuid" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -25,7 +26,7 @@ import ( // @Success 200 {object} repo.PaginationResult[repo.ItemSummary]{} // @Router /v1/items [GET] // @Security Bearer -func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemsGetAll() server.HandlerFunc { uuidList := func(params url.Values, key string) []uuid.UUID { var ids []uuid.UUID for _, id := range params[key] { @@ -58,15 +59,14 @@ func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc { } } - return func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) error { ctx := services.NewContext(r.Context()) items, err := ctrl.svc.Items.Query(ctx, extractQuery(r)) if err != nil { log.Err(err).Msg("failed to get items") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, items) + return server.Respond(w, http.StatusOK, items) } } @@ -78,24 +78,22 @@ func (ctrl *V1Controller) HandleItemsGetAll() http.HandlerFunc { // @Success 200 {object} repo.ItemSummary // @Router /v1/items [POST] // @Security Bearer -func (ctrl *V1Controller) HandleItemsCreate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleItemsCreate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { createData := repo.ItemCreate{} if err := server.Decode(r, &createData); err != nil { log.Err(err).Msg("failed to decode request body") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } user := services.UseUserCtx(r.Context()) item, err := ctrl.svc.Items.Create(r.Context(), user.GroupID, createData) if err != nil { log.Err(err).Msg("failed to create item") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusCreated, item) + return server.Respond(w, http.StatusCreated, item) } } @@ -107,7 +105,7 @@ func (ctrl *V1Controller) HandleItemsCreate() http.HandlerFunc { // @Success 200 {object} repo.ItemOut // @Router /v1/items/{id} [GET] // @Security Bearer -func (ctrl *V1Controller) HandleItemGet() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemGet() server.HandlerFunc { return ctrl.handleItemsGeneral() } @@ -119,7 +117,7 @@ func (ctrl *V1Controller) HandleItemGet() http.HandlerFunc { // @Success 204 // @Router /v1/items/{id} [DELETE] // @Security Bearer -func (ctrl *V1Controller) HandleItemDelete() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemDelete() server.HandlerFunc { return ctrl.handleItemsGeneral() } @@ -132,16 +130,15 @@ func (ctrl *V1Controller) HandleItemDelete() http.HandlerFunc { // @Success 200 {object} repo.ItemOut // @Router /v1/items/{id} [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleItemUpdate() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemUpdate() server.HandlerFunc { return ctrl.handleItemsGeneral() } -func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) handleItemsGeneral() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { ctx := services.NewContext(r.Context()) ID, err := ctrl.routeID(w, r) if err != nil { - return } switch r.Method { @@ -149,37 +146,32 @@ func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc { items, err := ctrl.svc.Items.GetOne(r.Context(), ctx.GID, ID) if err != nil { log.Err(err).Msg("failed to get item") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, items) - return + return server.Respond(w, http.StatusOK, items) case http.MethodDelete: err = ctrl.svc.Items.Delete(r.Context(), ctx.GID, ID) if err != nil { log.Err(err).Msg("failed to delete item") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) - return + return server.Respond(w, http.StatusNoContent, nil) case http.MethodPut: body := repo.ItemUpdate{} if err := server.Decode(r, &body); err != nil { log.Err(err).Msg("failed to decode request body") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } body.ID = ID result, err := ctrl.svc.Items.Update(r.Context(), ctx.GID, body) if err != nil { log.Err(err).Msg("failed to update item") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, result) + return server.Respond(w, http.StatusOK, result) } + return nil } } @@ -191,29 +183,26 @@ func (ctrl *V1Controller) handleItemsGeneral() http.HandlerFunc { // @Param csv formData file true "Image to upload" // @Router /v1/items/import [Post] // @Security Bearer -func (ctrl *V1Controller) HandleItemsImport() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleItemsImport() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { err := r.ParseMultipartForm(ctrl.maxUploadSize << 20) if err != nil { log.Err(err).Msg("failed to parse multipart form") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } file, _, err := r.FormFile("csv") if err != nil { log.Err(err).Msg("failed to get file from form") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } reader := csv.NewReader(file) data, err := reader.ReadAll() if err != nil { log.Err(err).Msg("failed to read csv") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } user := services.UseUserCtx(r.Context()) @@ -221,10 +210,9 @@ func (ctrl *V1Controller) HandleItemsImport() http.HandlerFunc { _, err = ctrl.svc.Items.CsvImport(r.Context(), user.GroupID, data) if err != nil { log.Err(err).Msg("failed to import items") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } } diff --git a/backend/app/api/handlers/v1/v1_ctrl_items_attachments.go b/backend/app/api/handlers/v1/v1_ctrl_items_attachments.go index d4e2981..436761d 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_items_attachments.go +++ b/backend/app/api/handlers/v1/v1_ctrl_items_attachments.go @@ -10,6 +10,7 @@ import ( "github.com/hay-kot/homebox/backend/ent/attachment" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -32,13 +33,13 @@ type ( // @Failure 422 {object} []server.ValidationError // @Router /v1/items/{id}/attachments [POST] // @Security Bearer -func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleItemAttachmentCreate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { err := r.ParseMultipartForm(ctrl.maxUploadSize << 20) if err != nil { log.Err(err).Msg("failed to parse multipart form") - server.RespondError(w, http.StatusBadRequest, errors.New("failed to parse multipart form")) - return + return validate.NewRequestError(errors.New("failed to parse multipart form"), http.StatusBadRequest) + } errs := make(server.ValidationErrors, 0) @@ -51,8 +52,7 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { errs = errs.Append("file", "file is required") default: log.Err(err).Msg("failed to get file from form") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } } @@ -63,8 +63,7 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { } if errs.HasErrors() { - server.Respond(w, http.StatusUnprocessableEntity, errs) - return + return server.Respond(w, http.StatusUnprocessableEntity, errs) } attachmentType := r.FormValue("type") @@ -74,7 +73,6 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { id, err := ctrl.routeID(w, r) if err != nil { - return } ctx := services.NewContext(r.Context()) @@ -89,11 +87,10 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { if err != nil { log.Err(err).Msg("failed to add attachment") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusCreated, item) + return server.Respond(w, http.StatusCreated, item) } } @@ -106,21 +103,21 @@ func (ctrl *V1Controller) HandleItemAttachmentCreate() http.HandlerFunc { // @Success 200 // @Router /v1/items/{id}/attachments/download [GET] // @Security Bearer -func (ctrl *V1Controller) HandleItemAttachmentDownload() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleItemAttachmentDownload() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { token := server.GetParam(r, "token", "") doc, err := ctrl.svc.Items.AttachmentPath(r.Context(), token) if err != nil { log.Err(err).Msg("failed to get attachment") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } w.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", doc.Title)) w.Header().Set("Content-Type", "application/octet-stream") http.ServeFile(w, r, doc.Path) + return nil } } @@ -133,7 +130,7 @@ func (ctrl *V1Controller) HandleItemAttachmentDownload() http.HandlerFunc { // @Success 200 {object} ItemAttachmentToken // @Router /v1/items/{id}/attachments/{attachment_id} [GET] // @Security Bearer -func (ctrl *V1Controller) HandleItemAttachmentToken() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemAttachmentToken() server.HandlerFunc { return ctrl.handleItemAttachmentsHandler } @@ -145,7 +142,7 @@ func (ctrl *V1Controller) HandleItemAttachmentToken() http.HandlerFunc { // @Success 204 // @Router /v1/items/{id}/attachments/{attachment_id} [DELETE] // @Security Bearer -func (ctrl *V1Controller) HandleItemAttachmentDelete() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemAttachmentDelete() server.HandlerFunc { return ctrl.handleItemAttachmentsHandler } @@ -158,21 +155,19 @@ func (ctrl *V1Controller) HandleItemAttachmentDelete() http.HandlerFunc { // @Success 200 {object} repo.ItemOut // @Router /v1/items/{id}/attachments/{attachment_id} [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleItemAttachmentUpdate() http.HandlerFunc { +func (ctrl *V1Controller) HandleItemAttachmentUpdate() server.HandlerFunc { return ctrl.handleItemAttachmentsHandler } -func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r *http.Request) error { ID, err := ctrl.routeID(w, r) if err != nil { - return } attachmentId, err := uuid.Parse(chi.URLParam(r, "attachment_id")) if err != nil { log.Err(err).Msg("failed to parse attachment_id param") - server.RespondError(w, http.StatusBadRequest, err) - return + return validate.NewRequestError(err, http.StatusBadRequest) } ctx := services.NewContext(r.Context()) @@ -189,7 +184,7 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r Str("id", attachmentId.String()). Msg("failed to find attachment with id") - server.RespondError(w, http.StatusNotFound, err) + return validate.NewRequestError(err, http.StatusNotFound) case services.ErrFileNotFound: log.Err(err). @@ -197,27 +192,25 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r Msg("failed to find file path for attachment with id") log.Warn().Msg("attachment with no file path removed from database") - server.RespondError(w, http.StatusNotFound, err) + return validate.NewRequestError(err, http.StatusNotFound) default: log.Err(err).Msg("failed to get attachment") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } } - server.Respond(w, http.StatusOK, ItemAttachmentToken{Token: token}) + return server.Respond(w, http.StatusOK, ItemAttachmentToken{Token: token}) // Delete Attachment Handler case http.MethodDelete: err = ctrl.svc.Items.AttachmentDelete(r.Context(), ctx.GID, ID, attachmentId) if err != nil { log.Err(err).Msg("failed to delete attachment") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) // Update Attachment Handler case http.MethodPut: @@ -225,18 +218,18 @@ func (ctrl *V1Controller) handleItemAttachmentsHandler(w http.ResponseWriter, r err = server.Decode(r, &attachment) if err != nil { log.Err(err).Msg("failed to decode attachment") - server.RespondError(w, http.StatusBadRequest, err) - return + return validate.NewRequestError(err, http.StatusBadRequest) } attachment.ID = attachmentId val, err := ctrl.svc.Items.AttachmentUpdate(ctx, ID, &attachment) if err != nil { log.Err(err).Msg("failed to delete attachment") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, val) + return server.Respond(w, http.StatusOK, val) } + + return nil } diff --git a/backend/app/api/handlers/v1/v1_ctrl_labels.go b/backend/app/api/handlers/v1/v1_ctrl_labels.go index 7d3c00d..c1a05d9 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_labels.go +++ b/backend/app/api/handlers/v1/v1_ctrl_labels.go @@ -6,6 +6,7 @@ import ( "github.com/hay-kot/homebox/backend/ent" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -17,16 +18,15 @@ import ( // @Success 200 {object} server.Results{items=[]repo.LabelOut} // @Router /v1/labels [GET] // @Security Bearer -func (ctrl *V1Controller) HandleLabelsGetAll() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleLabelsGetAll() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { user := services.UseUserCtx(r.Context()) labels, err := ctrl.svc.Labels.GetAll(r.Context(), user.GroupID) if err != nil { log.Err(err).Msg("error getting labels") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, server.Results{Items: labels}) + return server.Respond(w, http.StatusOK, server.Results{Items: labels}) } } @@ -38,24 +38,22 @@ func (ctrl *V1Controller) HandleLabelsGetAll() http.HandlerFunc { // @Success 200 {object} repo.LabelSummary // @Router /v1/labels [POST] // @Security Bearer -func (ctrl *V1Controller) HandleLabelsCreate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleLabelsCreate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { createData := repo.LabelCreate{} if err := server.Decode(r, &createData); err != nil { log.Err(err).Msg("error decoding label create data") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } user := services.UseUserCtx(r.Context()) label, err := ctrl.svc.Labels.Create(r.Context(), user.GroupID, createData) if err != nil { log.Err(err).Msg("error creating label") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusCreated, label) + return server.Respond(w, http.StatusCreated, label) } } @@ -67,7 +65,7 @@ func (ctrl *V1Controller) HandleLabelsCreate() http.HandlerFunc { // @Success 204 // @Router /v1/labels/{id} [DELETE] // @Security Bearer -func (ctrl *V1Controller) HandleLabelDelete() http.HandlerFunc { +func (ctrl *V1Controller) HandleLabelDelete() server.HandlerFunc { return ctrl.handleLabelsGeneral() } @@ -79,7 +77,7 @@ func (ctrl *V1Controller) HandleLabelDelete() http.HandlerFunc { // @Success 200 {object} repo.LabelOut // @Router /v1/labels/{id} [GET] // @Security Bearer -func (ctrl *V1Controller) HandleLabelGet() http.HandlerFunc { +func (ctrl *V1Controller) HandleLabelGet() server.HandlerFunc { return ctrl.handleLabelsGeneral() } @@ -91,16 +89,15 @@ func (ctrl *V1Controller) HandleLabelGet() http.HandlerFunc { // @Success 200 {object} repo.LabelOut // @Router /v1/labels/{id} [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleLabelUpdate() http.HandlerFunc { +func (ctrl *V1Controller) HandleLabelUpdate() server.HandlerFunc { return ctrl.handleLabelsGeneral() } -func (ctrl *V1Controller) handleLabelsGeneral() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) handleLabelsGeneral() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { ctx := services.NewContext(r.Context()) ID, err := ctrl.routeID(w, r) if err != nil { - return } switch r.Method { @@ -111,40 +108,37 @@ func (ctrl *V1Controller) handleLabelsGeneral() http.HandlerFunc { log.Err(err). Str("id", ID.String()). Msg("label not found") - server.RespondError(w, http.StatusNotFound, err) - return + return validate.NewRequestError(err, http.StatusNotFound) } log.Err(err).Msg("error getting label") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, labels) + return server.Respond(w, http.StatusOK, labels) case http.MethodDelete: err = ctrl.svc.Labels.Delete(r.Context(), ctx.GID, ID) if err != nil { log.Err(err).Msg("error deleting label") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) case http.MethodPut: body := repo.LabelUpdate{} if err := server.Decode(r, &body); err != nil { log.Err(err).Msg("error decoding label update data") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } body.ID = ID result, err := ctrl.svc.Labels.Update(r.Context(), ctx.GID, body) if err != nil { log.Err(err).Msg("error updating label") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, result) + return server.Respond(w, http.StatusOK, result) } + + return nil } } diff --git a/backend/app/api/handlers/v1/v1_ctrl_locations.go b/backend/app/api/handlers/v1/v1_ctrl_locations.go index 7a9e525..6b7bed7 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_locations.go +++ b/backend/app/api/handlers/v1/v1_ctrl_locations.go @@ -6,6 +6,7 @@ import ( "github.com/hay-kot/homebox/backend/ent" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -17,17 +18,16 @@ import ( // @Success 200 {object} server.Results{items=[]repo.LocationOutCount} // @Router /v1/locations [GET] // @Security Bearer -func (ctrl *V1Controller) HandleLocationGetAll() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleLocationGetAll() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { user := services.UseUserCtx(r.Context()) locations, err := ctrl.svc.Location.GetAll(r.Context(), user.GroupID) if err != nil { log.Err(err).Msg("failed to get locations") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, server.Results{Items: locations}) + return server.Respond(w, http.StatusOK, server.Results{Items: locations}) } } @@ -39,24 +39,22 @@ func (ctrl *V1Controller) HandleLocationGetAll() http.HandlerFunc { // @Success 200 {object} repo.LocationSummary // @Router /v1/locations [POST] // @Security Bearer -func (ctrl *V1Controller) HandleLocationCreate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleLocationCreate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { createData := repo.LocationCreate{} if err := server.Decode(r, &createData); err != nil { log.Err(err).Msg("failed to decode location create data") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } user := services.UseUserCtx(r.Context()) location, err := ctrl.svc.Location.Create(r.Context(), user.GroupID, createData) if err != nil { log.Err(err).Msg("failed to create location") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusCreated, location) + return server.Respond(w, http.StatusCreated, location) } } @@ -68,7 +66,7 @@ func (ctrl *V1Controller) HandleLocationCreate() http.HandlerFunc { // @Success 204 // @Router /v1/locations/{id} [DELETE] // @Security Bearer -func (ctrl *V1Controller) HandleLocationDelete() http.HandlerFunc { +func (ctrl *V1Controller) HandleLocationDelete() server.HandlerFunc { return ctrl.handleLocationGeneral() } @@ -80,7 +78,7 @@ func (ctrl *V1Controller) HandleLocationDelete() http.HandlerFunc { // @Success 200 {object} repo.LocationOut // @Router /v1/locations/{id} [GET] // @Security Bearer -func (ctrl *V1Controller) HandleLocationGet() http.HandlerFunc { +func (ctrl *V1Controller) HandleLocationGet() server.HandlerFunc { return ctrl.handleLocationGeneral() } @@ -93,16 +91,15 @@ func (ctrl *V1Controller) HandleLocationGet() http.HandlerFunc { // @Success 200 {object} repo.LocationOut // @Router /v1/locations/{id} [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleLocationUpdate() http.HandlerFunc { +func (ctrl *V1Controller) HandleLocationUpdate() server.HandlerFunc { return ctrl.handleLocationGeneral() } -func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) handleLocationGeneral() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { ctx := services.NewContext(r.Context()) ID, err := ctrl.routeID(w, r) if err != nil { - return } switch r.Method { @@ -115,21 +112,18 @@ func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc { if ent.IsNotFound(err) { l.Msg("location not found") - server.RespondError(w, http.StatusNotFound, err) - return + return validate.NewRequestError(err, http.StatusNotFound) } l.Msg("failed to get location") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, location) + return server.Respond(w, http.StatusOK, location) case http.MethodPut: body := repo.LocationUpdate{} if err := server.Decode(r, &body); err != nil { log.Err(err).Msg("failed to decode location update data") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } body.ID = ID @@ -137,18 +131,17 @@ func (ctrl *V1Controller) handleLocationGeneral() http.HandlerFunc { result, err := ctrl.svc.Location.Update(r.Context(), ctx.GID, body) if err != nil { log.Err(err).Msg("failed to update location") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, result) + return server.Respond(w, http.StatusOK, result) case http.MethodDelete: err = ctrl.svc.Location.Delete(r.Context(), ctx.GID, ID) if err != nil { log.Err(err).Msg("failed to delete location") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } + return nil } } diff --git a/backend/app/api/handlers/v1/v1_ctrl_user.go b/backend/app/api/handlers/v1/v1_ctrl_user.go index 3b8fd44..b7bf5ee 100644 --- a/backend/app/api/handlers/v1/v1_ctrl_user.go +++ b/backend/app/api/handlers/v1/v1_ctrl_user.go @@ -6,6 +6,7 @@ import ( "github.com/google/uuid" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" "github.com/rs/zerolog/log" ) @@ -17,29 +18,26 @@ import ( // @Param payload body services.UserRegistration true "User Data" // @Success 204 // @Router /v1/users/register [Post] -func (ctrl *V1Controller) HandleUserRegistration() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleUserRegistration() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { regData := services.UserRegistration{} if err := server.Decode(r, ®Data); err != nil { log.Err(err).Msg("failed to decode user registration data") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } if !ctrl.allowRegistration && regData.GroupToken == "" { - server.RespondError(w, http.StatusForbidden, nil) - return + return validate.NewRequestError(nil, http.StatusForbidden) } _, err := ctrl.svc.User.RegisterUser(r.Context(), regData) if err != nil { log.Err(err).Msg("failed to register user") - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } } @@ -50,17 +48,16 @@ func (ctrl *V1Controller) HandleUserRegistration() http.HandlerFunc { // @Success 200 {object} server.Result{item=repo.UserOut} // @Router /v1/users/self [GET] // @Security Bearer -func (ctrl *V1Controller) HandleUserSelf() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleUserSelf() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { token := services.UseTokenCtx(r.Context()) usr, err := ctrl.svc.User.GetSelf(r.Context(), token) if usr.ID == uuid.Nil || err != nil { log.Err(err).Msg("failed to get user") - server.RespondServerError(w) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, server.Wrap(usr)) + return server.Respond(w, http.StatusOK, server.Wrap(usr)) } } @@ -72,24 +69,22 @@ func (ctrl *V1Controller) HandleUserSelf() http.HandlerFunc { // @Success 200 {object} server.Result{item=repo.UserUpdate} // @Router /v1/users/self [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleUserSelfUpdate() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleUserSelfUpdate() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { updateData := repo.UserUpdate{} if err := server.Decode(r, &updateData); err != nil { log.Err(err).Msg("failed to decode user update data") - server.RespondError(w, http.StatusBadRequest, err) - return + return validate.NewRequestError(err, http.StatusBadRequest) } actor := services.UseUserCtx(r.Context()) newData, err := ctrl.svc.User.UpdateSelf(r.Context(), actor.ID, updateData) if err != nil { - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusOK, server.Wrap(newData)) + return server.Respond(w, http.StatusOK, server.Wrap(newData)) } } @@ -100,20 +95,18 @@ func (ctrl *V1Controller) HandleUserSelfUpdate() http.HandlerFunc { // @Success 204 // @Router /v1/users/self [DELETE] // @Security Bearer -func (ctrl *V1Controller) HandleUserSelfDelete() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleUserSelfDelete() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { if ctrl.isDemo { - server.RespondError(w, http.StatusForbidden, nil) - return + return validate.NewRequestError(nil, http.StatusForbidden) } actor := services.UseUserCtx(r.Context()) if err := ctrl.svc.User.DeleteSelf(r.Context(), actor.ID); err != nil { - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } } @@ -131,11 +124,10 @@ type ( // @Param payload body ChangePassword true "Password Payload" // @Router /v1/users/change-password [PUT] // @Security Bearer -func (ctrl *V1Controller) HandleUserSelfChangePassword() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { +func (ctrl *V1Controller) HandleUserSelfChangePassword() server.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) error { if ctrl.isDemo { - server.RespondError(w, http.StatusForbidden, nil) - return + return validate.NewRequestError(nil, http.StatusForbidden) } var cp ChangePassword @@ -148,10 +140,9 @@ func (ctrl *V1Controller) HandleUserSelfChangePassword() http.HandlerFunc { ok := ctrl.svc.User.ChangePassword(ctx, cp.Current, cp.New) if !ok { - server.RespondError(w, http.StatusInternalServerError, err) - return + return validate.NewRequestError(err, http.StatusInternalServerError) } - server.Respond(w, http.StatusNoContent, nil) + return server.Respond(w, http.StatusNoContent, nil) } } diff --git a/backend/app/api/logger.go b/backend/app/api/logger.go index 6756ffc..86085a5 100644 --- a/backend/app/api/logger.go +++ b/backend/app/api/logger.go @@ -15,7 +15,7 @@ func (a *app) setupLogger() { // Logger Init // zerolog.TimeFieldFormat = zerolog.TimeFormatUnix if a.conf.Log.Format != config.LogFormatJSON { - log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: os.Stderr}).With().Caller().Logger() } log.Level(getLevel(a.conf.Log.Level)) diff --git a/backend/app/api/main.go b/backend/app/api/main.go index 9cbff98..e48ae04 100644 --- a/backend/app/api/main.go +++ b/backend/app/api/main.go @@ -15,6 +15,7 @@ import ( "github.com/hay-kot/homebox/backend/internal/migrations" "github.com/hay-kot/homebox/backend/internal/repo" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/web/mid" "github.com/hay-kot/homebox/backend/pkgs/server" _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" @@ -114,17 +115,25 @@ func run(cfg *config.Config) error { app.services = services.New(app.repos) // ========================================================================= - // Start Server + // Start Server\ + logger := log.With().Caller().Logger() + + mwLogger := mid.Logger(logger) + if app.conf.Mode == config.ModeDevelopment { + mwLogger = mid.SugarLogger(logger) + } + app.server = server.NewServer( server.WithHost(app.conf.Web.Host), server.WithPort(app.conf.Web.Port), + server.WithMiddleware( + mwLogger, + mid.Errors(logger), + mid.Panic(app.conf.Mode == config.ModeDevelopment), + ), ) - routes := app.newRouter(app.repos) - - if app.conf.Mode != config.ModeDevelopment { - app.logRoutes(routes) - } + app.mountRoutes(app.repos) log.Info().Msgf("Starting HTTP Server on %s:%s", app.server.Host, app.server.Port) @@ -163,5 +172,5 @@ func run(cfg *config.Config) error { }() } - return app.server.Start(routes) + return app.server.Start() } diff --git a/backend/app/api/middleware.go b/backend/app/api/middleware.go index 7db65aa..68bfde1 100644 --- a/backend/app/api/middleware.go +++ b/backend/app/api/middleware.go @@ -1,143 +1,34 @@ package main import ( - "fmt" + "errors" "net/http" "strings" - "time" - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" - "github.com/hay-kot/homebox/backend/internal/config" "github.com/hay-kot/homebox/backend/internal/services" + "github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/pkgs/server" - "github.com/rs/zerolog/log" ) -func (a *app) setGlobalMiddleware(r *chi.Mux) { - // ========================================================================= - // Middleware - r.Use(middleware.RequestID) - r.Use(middleware.RealIP) - r.Use(mwStripTrailingSlash) - - // Use struct logger in production for requests, but use - // pretty console logger in development. - if a.conf.Mode == config.ModeDevelopment { - r.Use(a.mwSummaryLogger) - } else { - r.Use(a.mwStructLogger) - } - r.Use(middleware.Recoverer) - - // Set a timeout value on the request context (ctx), that will signal - // through ctx.Done() that the request has timed out and further - // processing should be stopped. - r.Use(middleware.Timeout(60 * time.Second)) -} - // mwAuthToken is a middleware that will check the database for a stateful token // and attach it to the request context with the user, or return a 401 if it doesn't exist. -func (a *app) mwAuthToken(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func (a *app) mwAuthToken(next server.Handler) server.Handler { + return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { requestToken := r.Header.Get("Authorization") if requestToken == "" { - server.RespondUnauthorized(w) - return + return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized) } requestToken = strings.TrimPrefix(requestToken, "Bearer ") usr, err := a.services.User.GetSelf(r.Context(), requestToken) // Check the database for the token - if err != nil { - server.RespondUnauthorized(w) - return + return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized) } r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken)) - - next.ServeHTTP(w, r) - }) -} - -// mqStripTrailingSlash is a middleware that will strip trailing slashes from the request path. -func mwStripTrailingSlash(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") - next.ServeHTTP(w, r) - }) -} - -type StatusRecorder struct { - http.ResponseWriter - Status int -} - -func (r *StatusRecorder) WriteHeader(status int) { - r.Status = status - r.ResponseWriter.WriteHeader(status) -} - -func (a *app) mwStructLogger(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK} - next.ServeHTTP(record, r) - - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto) - - log.Info(). - Str("id", middleware.GetReqID(r.Context())). - Str("method", r.Method). - Str("remote_addr", r.RemoteAddr). - Int("status", record.Status). - Msg(url) - }) -} - -func (a *app) mwSummaryLogger(next http.Handler) http.Handler { - bold := func(s string) string { return "\033[1m" + s + "\033[0m" } - orange := func(s string) string { return "\033[33m" + s + "\033[0m" } - aqua := func(s string) string { return "\033[36m" + s + "\033[0m" } - red := func(s string) string { return "\033[31m" + s + "\033[0m" } - green := func(s string) string { return "\033[32m" + s + "\033[0m" } - - fmtCode := func(code int) string { - switch { - case code >= 500: - return red(fmt.Sprintf("%d", code)) - case code >= 400: - return orange(fmt.Sprintf("%d", code)) - case code >= 300: - return aqua(fmt.Sprintf("%d", code)) - default: - return green(fmt.Sprintf("%d", code)) - } - } - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK} - next.ServeHTTP(record, r) // Blocks until the next handler returns. - - scheme := "http" - if r.TLS != nil { - scheme = "https" - } - - url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto) - - log.Info(). - Msgf("%s %s %s", - bold(orange(""+r.Method+"")), - aqua(url), - bold(fmtCode(record.Status)), - ) + return next.ServeHTTP(w, r) }) } diff --git a/backend/app/api/routes.go b/backend/app/api/routes.go index f5d887a..4813e43 100644 --- a/backend/app/api/routes.go +++ b/backend/app/api/routes.go @@ -15,6 +15,7 @@ import ( v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1" _ "github.com/hay-kot/homebox/backend/app/api/static/docs" "github.com/hay-kot/homebox/backend/internal/repo" + "github.com/hay-kot/homebox/backend/pkgs/server" httpSwagger "github.com/swaggo/http-swagger" // http-swagger middleware ) @@ -35,80 +36,76 @@ func (a *app) debugRouter() *http.ServeMux { } // registerRoutes registers all the routes for the API -func (a *app) newRouter(repos *repo.AllRepos) *chi.Mux { +func (a *app) mountRoutes(repos *repo.AllRepos) { registerMimes() - r := chi.NewRouter() - a.setGlobalMiddleware(r) - - r.Get("/swagger/*", httpSwagger.Handler( + a.server.Get("/swagger/*", server.ToHandler(httpSwagger.Handler( httpSwagger.URL(fmt.Sprintf("%s://%s/swagger/doc.json", a.conf.Swagger.Scheme, a.conf.Swagger.Host)), - )) + ))) // ========================================================================= // API Version 1 v1Base := v1.BaseUrlFunc(prefix) - v1Ctrl := v1.NewControllerV1(a.services, + + v1Ctrl := v1.NewControllerV1( + a.services, v1.WithMaxUploadSize(a.conf.Web.MaxUploadSize), v1.WithRegistration(a.conf.AllowRegistration), v1.WithDemoStatus(a.conf.Demo), // Disable Password Change in Demo Mode ) - r.Get(v1Base("/status"), v1Ctrl.HandleBase(func() bool { return true }, v1.Build{ + + a.server.Get(v1Base("/status"), v1Ctrl.HandleBase(func() bool { return true }, v1.Build{ Version: version, Commit: commit, BuildTime: buildTime, })) - r.Post(v1Base("/users/register"), v1Ctrl.HandleUserRegistration()) - r.Post(v1Base("/users/login"), v1Ctrl.HandleAuthLogin()) + a.server.Post(v1Base("/users/register"), v1Ctrl.HandleUserRegistration()) + a.server.Post(v1Base("/users/login"), v1Ctrl.HandleAuthLogin()) // Attachment download URl needs a `token` query param to be passed in the request. // and also needs to be outside of the `auth` middleware. - r.Get(v1Base("/items/{id}/attachments/download"), v1Ctrl.HandleItemAttachmentDownload()) + a.server.Get(v1Base("/items/{id}/attachments/download"), v1Ctrl.HandleItemAttachmentDownload()) - r.Group(func(r chi.Router) { - r.Use(a.mwAuthToken) - r.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf()) - r.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate()) - r.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete()) - r.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout()) - r.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh()) - r.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword()) + a.server.Get(v1Base("/users/self"), v1Ctrl.HandleUserSelf(), a.mwAuthToken) + a.server.Put(v1Base("/users/self"), v1Ctrl.HandleUserSelfUpdate(), a.mwAuthToken) + a.server.Delete(v1Base("/users/self"), v1Ctrl.HandleUserSelfDelete(), a.mwAuthToken) + a.server.Post(v1Base("/users/logout"), v1Ctrl.HandleAuthLogout(), a.mwAuthToken) + a.server.Get(v1Base("/users/refresh"), v1Ctrl.HandleAuthRefresh(), a.mwAuthToken) + a.server.Put(v1Base("/users/self/change-password"), v1Ctrl.HandleUserSelfChangePassword(), a.mwAuthToken) - r.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate()) + a.server.Post(v1Base("/groups/invitations"), v1Ctrl.HandleGroupInvitationsCreate(), a.mwAuthToken) - // TODO: I don't like /groups being the URL for users - r.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet()) - r.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate()) + // TODO: I don't like /groups being the URL for users + a.server.Get(v1Base("/groups"), v1Ctrl.HandleGroupGet(), a.mwAuthToken) + a.server.Put(v1Base("/groups"), v1Ctrl.HandleGroupUpdate(), a.mwAuthToken) - r.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll()) - r.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate()) - r.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet()) - r.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate()) - r.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete()) + a.server.Get(v1Base("/locations"), v1Ctrl.HandleLocationGetAll(), a.mwAuthToken) + a.server.Post(v1Base("/locations"), v1Ctrl.HandleLocationCreate(), a.mwAuthToken) + a.server.Get(v1Base("/locations/{id}"), v1Ctrl.HandleLocationGet(), a.mwAuthToken) + a.server.Put(v1Base("/locations/{id}"), v1Ctrl.HandleLocationUpdate(), a.mwAuthToken) + a.server.Delete(v1Base("/locations/{id}"), v1Ctrl.HandleLocationDelete(), a.mwAuthToken) - r.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll()) - r.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate()) - r.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet()) - r.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate()) - r.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete()) + a.server.Get(v1Base("/labels"), v1Ctrl.HandleLabelsGetAll(), a.mwAuthToken) + a.server.Post(v1Base("/labels"), v1Ctrl.HandleLabelsCreate(), a.mwAuthToken) + a.server.Get(v1Base("/labels/{id}"), v1Ctrl.HandleLabelGet(), a.mwAuthToken) + a.server.Put(v1Base("/labels/{id}"), v1Ctrl.HandleLabelUpdate(), a.mwAuthToken) + a.server.Delete(v1Base("/labels/{id}"), v1Ctrl.HandleLabelDelete(), a.mwAuthToken) - r.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll()) - r.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport()) - r.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate()) - r.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet()) - r.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate()) - r.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete()) + a.server.Get(v1Base("/items"), v1Ctrl.HandleItemsGetAll(), a.mwAuthToken) + a.server.Post(v1Base("/items/import"), v1Ctrl.HandleItemsImport(), a.mwAuthToken) + a.server.Post(v1Base("/items"), v1Ctrl.HandleItemsCreate(), a.mwAuthToken) + a.server.Get(v1Base("/items/{id}"), v1Ctrl.HandleItemGet(), a.mwAuthToken) + a.server.Put(v1Base("/items/{id}"), v1Ctrl.HandleItemUpdate(), a.mwAuthToken) + a.server.Delete(v1Base("/items/{id}"), v1Ctrl.HandleItemDelete(), a.mwAuthToken) - r.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate()) - r.Get(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentToken()) - r.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate()) - r.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete()) - }) + a.server.Post(v1Base("/items/{id}/attachments"), v1Ctrl.HandleItemAttachmentCreate(), a.mwAuthToken) + a.server.Get(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentToken(), a.mwAuthToken) + a.server.Put(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentUpdate(), a.mwAuthToken) + a.server.Delete(v1Base("/items/{id}/attachments/{attachment_id}"), v1Ctrl.HandleItemAttachmentDelete(), a.mwAuthToken) - r.NotFound(notFoundHandler()) - return r + a.server.NotFound(notFoundHandler()) } // logRoutes logs the routes of the server that are registered within Server.registerRoutes(). This is useful for debugging. @@ -146,7 +143,7 @@ func registerMimes() { // notFoundHandler perform the main logic around handling the internal SPA embed and ensuring that // the client side routing is handled correctly. -func notFoundHandler() http.HandlerFunc { +func notFoundHandler() server.HandlerFunc { tryRead := func(fs embed.FS, prefix, requestedPath string, w http.ResponseWriter) error { f, err := fs.Open(path.Join(prefix, requestedPath)) if err != nil { @@ -165,14 +162,16 @@ func notFoundHandler() http.HandlerFunc { return err } - return func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) error { err := tryRead(public, "static/public", r.URL.Path, w) - if err == nil { - return - } - err = tryRead(public, "static/public", "index.html", w) if err != nil { - panic(err) + // Fallback to the index.html file. + // should succeed in all cases. + err = tryRead(public, "static/public", "index.html", w) + if err != nil { + return err + } } + return nil } } diff --git a/backend/internal/sys/validate/errors.go b/backend/internal/sys/validate/errors.go new file mode 100644 index 0000000..b4fd58d --- /dev/null +++ b/backend/internal/sys/validate/errors.go @@ -0,0 +1,82 @@ +package validate + +import ( + "encoding/json" + "errors" +) + +func UnauthorizedError() error { + return errors.New("unauthorized") +} + +// ErrInvalidID occurs when an ID is not in a valid form. +var ErrInvalidID = errors.New("ID is not in its proper form") + +// ErrorResponse is the form used for API responses from failures in the API. +type ErrorResponse struct { + Error string `json:"error"` + Fields string `json:"fields,omitempty"` +} + +// RequestError is used to pass an error during the request through the +// application with web specific context. +type RequestError struct { + Err error + Status int + Fields error +} + +// NewRequestError wraps a provided error with an HTTP status code. This +// function should be used when handlers encounter expected errors. +func NewRequestError(err error, status int) error { + return &RequestError{err, status, nil} +} + +func (err *RequestError) Error() string { + return err.Err.Error() +} + +// IsRequestError checks if an error of type RequestError exists. +func IsRequestError(err error) bool { + var re *RequestError + return errors.As(err, &re) +} + +// FieldError is used to indicate an error with a specific request field. +type FieldError struct { + Field string `json:"field"` + Error string `json:"error"` +} + +// FieldErrors represents a collection of field errors. +type FieldErrors []FieldError + +// Error implments the error interface. +func (fe FieldErrors) Error() string { + d, err := json.Marshal(fe) + if err != nil { + return err.Error() + } + return string(d) +} + +func NewFieldErrors(errs ...FieldError) FieldErrors { + return errs +} + +func IsFieldError(err error) bool { + v := FieldErrors{} + return errors.As(err, &v) +} + +// Cause iterates through all the wrapped errors until the root +// error value is reached. +func Cause(err error) error { + root := err + for { + if err = errors.Unwrap(root); err == nil { + return root + } + root = err + } +} diff --git a/backend/internal/web/mid/errors.go b/backend/internal/web/mid/errors.go new file mode 100644 index 0000000..fd761d3 --- /dev/null +++ b/backend/internal/web/mid/errors.go @@ -0,0 +1,66 @@ +package mid + +import ( + "errors" + "net/http" + + "github.com/hay-kot/homebox/backend/ent" + "github.com/hay-kot/homebox/backend/internal/sys/validate" + "github.com/hay-kot/homebox/backend/pkgs/server" + "github.com/rs/zerolog" +) + +func Errors(log zerolog.Logger) server.Middleware { + return func(h server.Handler) server.Handler { + return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + err := h.ServeHTTP(w, r) + + if err != nil { + var resp server.ErrorResponse + var code int + + log.Err(err). + Str("trace_id", "TODO"). + Msg("ERROR occurred") + + switch { + case errors.Is(err, validate.ErrInvalidID): + code = http.StatusBadRequest + resp = server.ErrorResponse{ + Error: "invalid id parameter", + } + case validate.IsFieldError(err): + fieldErrors := err.(validate.FieldErrors) + resp.Error = "Validation Error" + resp.Fields = map[string]string{} + + for _, fieldError := range fieldErrors { + resp.Fields[fieldError.Field] = fieldError.Error + } + case validate.IsRequestError(err): + requestError := err.(*validate.RequestError) + resp.Error = requestError.Error() + code = requestError.Status + case ent.IsNotFound(err): + resp.Error = "Not Found" + code = http.StatusNotFound + default: + resp.Error = "Unknown Error" + code = http.StatusInternalServerError + + } + + if err := server.Respond(w, code, resp); err != nil { + return err + } + + // If Showdown error, return error + if server.IsShutdownError(err) { + return err + } + } + + return nil + }) + } +} diff --git a/backend/internal/web/mid/logger.go b/backend/internal/web/mid/logger.go new file mode 100644 index 0000000..b36f95c --- /dev/null +++ b/backend/internal/web/mid/logger.go @@ -0,0 +1,94 @@ +package mid + +import ( + "fmt" + "net/http" + + "github.com/hay-kot/homebox/backend/pkgs/server" + "github.com/rs/zerolog" +) + +type StatusRecorder struct { + http.ResponseWriter + Status int +} + +func (r *StatusRecorder) WriteHeader(status int) { + r.Status = status + r.ResponseWriter.WriteHeader(status) +} + +func Logger(log zerolog.Logger) server.Middleware { + return func(next server.Handler) server.Handler { + return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + + log.Info(). + Str("trace_id", "TODO"). + Str("method", r.Method). + Str("path", r.URL.Path). + Str("remove_address", r.RemoteAddr). + Msg("request started") + + record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK} + + err := next.ServeHTTP(record, r) + + log.Info(). + Str("trave_id", "TODO"). + Str("method", r.Method). + Str("url", r.URL.Path). + Str("remote_address", r.RemoteAddr). + Int("status_code", record.Status). + Msg("request completed") + + return err + + }) + } +} + +func SugarLogger(log zerolog.Logger) server.Middleware { + orange := func(s string) string { return "\033[33m" + s + "\033[0m" } + aqua := func(s string) string { return "\033[36m" + s + "\033[0m" } + red := func(s string) string { return "\033[31m" + s + "\033[0m" } + green := func(s string) string { return "\033[32m" + s + "\033[0m" } + + fmtCode := func(code int) string { + switch { + case code >= 500: + return red(fmt.Sprintf("%d", code)) + case code >= 400: + return orange(fmt.Sprintf("%d", code)) + case code >= 300: + return aqua(fmt.Sprintf("%d", code)) + default: + return green(fmt.Sprintf("%d", code)) + } + } + bold := func(s string) string { return "\033[1m" + s + "\033[0m" } + + return func(next server.Handler) server.Handler { + return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + + record := &StatusRecorder{ResponseWriter: w, Status: http.StatusOK} + + err := next.ServeHTTP(record, r) // Blocks until the next handler returns. + + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + url := fmt.Sprintf("%s://%s%s %s", scheme, r.Host, r.RequestURI, r.Proto) + + log.Info(). + Msgf("%s %s %s", + bold(fmtCode(record.Status)), + bold(orange(""+r.Method+"")), + aqua(url), + ) + + return err + }) + } +} diff --git a/backend/internal/web/mid/panic.go b/backend/internal/web/mid/panic.go new file mode 100644 index 0000000..9879bb8 --- /dev/null +++ b/backend/internal/web/mid/panic.go @@ -0,0 +1,33 @@ +package mid + +import ( + "fmt" + "net/http" + "runtime/debug" + + "github.com/hay-kot/homebox/backend/pkgs/server" +) + +// Panic is a middleware that recovers from panics anywhere in the chain and wraps the error. +// and returns it up the middleware chain. +func Panic(develop bool) server.Middleware { + return func(h server.Handler) server.Handler { + return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (err error) { + defer func() { + if rec := recover(); rec != nil { + trace := debug.Stack() + + if develop { + err = fmt.Errorf("PANIC [%v]", rec) + fmt.Printf("%s", string(trace)) + } else { + err = fmt.Errorf("PANIC [%v] TRACE[%s]", rec, string(trace)) + } + + } + }() + + return h.ServeHTTP(w, r) + }) + } +} diff --git a/backend/pkgs/server/errors.go b/backend/pkgs/server/errors.go new file mode 100644 index 0000000..5b1d60b --- /dev/null +++ b/backend/pkgs/server/errors.go @@ -0,0 +1,23 @@ +package server + +import "errors" + +type shutdownError struct { + message string +} + +func (e *shutdownError) Error() string { + return e.message +} + +// ShutdownError returns an error that indicates that the server has lost +// integrity and should be shut down. +func ShutdownError(message string) error { + return &shutdownError{message} +} + +// IsShutdownError returns true if the error is a shutdown error. +func IsShutdownError(err error) bool { + var e *shutdownError + return errors.As(err, &e) +} diff --git a/backend/pkgs/server/handler.go b/backend/pkgs/server/handler.go new file mode 100644 index 0000000..76ae131 --- /dev/null +++ b/backend/pkgs/server/handler.go @@ -0,0 +1,25 @@ +package server + +import ( + "net/http" +) + +type HandlerFunc func(w http.ResponseWriter, r *http.Request) error + +func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error { + return f(w, r) +} + +type Handler interface { + ServeHTTP(http.ResponseWriter, *http.Request) error +} + +// ToHandler converts a function to a customer implementation of the Handler interface. +// that returns an error. This wrapper around the handler function and simply +// returns the nil in all cases +func ToHandler(handler http.Handler) Handler { + return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + handler.ServeHTTP(w, r) + return nil + }) +} diff --git a/backend/pkgs/server/middleware.go b/backend/pkgs/server/middleware.go new file mode 100644 index 0000000..acbe869 --- /dev/null +++ b/backend/pkgs/server/middleware.go @@ -0,0 +1,36 @@ +package server + +import ( + "net/http" + "strings" +) + +type Middleware func(Handler) Handler + +// wrapMiddleware creates a new handler by wrapping middleware around a final +// handler. The middlewares' Handlers will be executed by requests in the order +// they are provided. +func wrapMiddleware(mw []Middleware, handler Handler) Handler { + + // Loop backwards through the middleware invoking each one. Replace the + // handler with the new wrapped handler. Looping backwards ensures that the + // first middleware of the slice is the first to be executed by requests. + for i := len(mw) - 1; i >= 0; i-- { + h := mw[i] + if h != nil { + handler = h(handler) + } + } + + return handler +} + +// StripTrailingSlash is a middleware that will strip trailing slashes from the request path. +func StripTrailingSlash() Middleware { + return func(h Handler) Handler { + return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + r.URL.Path = strings.TrimSuffix(r.URL.Path, "/") + return h.ServeHTTP(w, r) + }) + } +} diff --git a/backend/pkgs/server/mux.go b/backend/pkgs/server/mux.go new file mode 100644 index 0000000..26deab0 --- /dev/null +++ b/backend/pkgs/server/mux.go @@ -0,0 +1,98 @@ +package server + +import ( + "context" + "net/http" +) + +type vkey int + +const ( + // Key is the key for the server in the request context. + key vkey = 1 +) + +type Values struct { + TraceID string +} + +func (s *Server) toHttpHandler(handler Handler, mw ...Middleware) http.HandlerFunc { + handler = wrapMiddleware(mw, handler) + + handler = wrapMiddleware(s.mw, handler) + + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Add the trace ID to the context + ctx = context.WithValue(ctx, key, Values{ + // TODO: Initialize a new trace ID + TraceID: "00000000-0000-0000-0000-000000000000", + }) + + err := handler.ServeHTTP(w, r.WithContext(ctx)) + + if err != nil { + if IsShutdownError(err) { + s.Shutdown("SIGTERM") + } + } + } +} + +func (s *Server) handle(method, pattern string, handler Handler, mw ...Middleware) { + h := s.toHttpHandler(handler, mw...) + + switch method { + case http.MethodGet: + s.mux.Get(pattern, h) + case http.MethodPost: + s.mux.Post(pattern, h) + case http.MethodPut: + s.mux.Put(pattern, h) + case http.MethodDelete: + s.mux.Delete(pattern, h) + case http.MethodPatch: + s.mux.Patch(pattern, h) + case http.MethodHead: + s.mux.Head(pattern, h) + case http.MethodOptions: + s.mux.Options(pattern, h) + } +} + +func (s *Server) Handler(pattern string, handler Handler, mw ...Middleware) { + s.mux.Handle(pattern, s.toHttpHandler(handler, mw...)) +} + +func (s *Server) Get(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodGet, pattern, handler, mw...) +} + +func (s *Server) Post(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodPost, pattern, handler, mw...) +} + +func (s *Server) Put(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodPut, pattern, handler, mw...) +} + +func (s *Server) Delete(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodDelete, pattern, handler, mw...) +} + +func (s *Server) Patch(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodPatch, pattern, handler, mw...) +} + +func (s *Server) Head(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodHead, pattern, handler, mw...) +} + +func (s *Server) Options(pattern string, handler Handler, mw ...Middleware) { + s.handle(http.MethodOptions, pattern, handler, mw...) +} + +func (s *Server) NotFound(handler Handler) { + s.mux.NotFound(s.toHttpHandler(handler)) +} diff --git a/backend/pkgs/server/response.go b/backend/pkgs/server/response.go index cf14dd2..7d5880e 100644 --- a/backend/pkgs/server/response.go +++ b/backend/pkgs/server/response.go @@ -2,16 +2,20 @@ package server import ( "encoding/json" - "errors" "net/http" ) +type ErrorResponse struct { + Error string `json:"error"` + Fields map[string]string `json:"fields,omitempty"` +} + // Respond converts a Go value to JSON and sends it to the client. // Adapted from https://github.com/ardanlabs/service/tree/master/foundation/web -func Respond(w http.ResponseWriter, statusCode int, data interface{}) { +func Respond(w http.ResponseWriter, statusCode int, data interface{}) error { if statusCode == http.StatusNoContent { w.WriteHeader(statusCode) - return + return nil } // Convert the response value to JSON. @@ -28,31 +32,8 @@ func Respond(w http.ResponseWriter, statusCode int, data interface{}) { // Send the result back to the client. if _, err := w.Write(jsonData); err != nil { - panic(err) + return err } -} -// ResponseError is a helper function that sends a JSON response of an error message -func RespondError(w http.ResponseWriter, statusCode int, err error) { - eb := ErrorBuilder{} - eb.AddError(err) - eb.Respond(w, statusCode) -} - -// RespondServerError is a wrapper around RespondError that sends a 500 internal server error. Useful for -// Sending generic errors when everything went wrong. -func RespondServerError(w http.ResponseWriter) { - RespondError(w, http.StatusInternalServerError, errors.New("internal server error")) -} - -// RespondNotFound is a helper utility for responding with a generic -// "unauthorized" error. -func RespondUnauthorized(w http.ResponseWriter) { - RespondError(w, http.StatusUnauthorized, errors.New("unauthorized")) -} - -// RespondForbidden is a helper utility for responding with a generic -// "forbidden" error. -func RespondForbidden(w http.ResponseWriter) { - RespondError(w, http.StatusForbidden, errors.New("forbidden")) + return nil } diff --git a/backend/pkgs/server/response_test.go b/backend/pkgs/server/response_test.go index a6446d3..a438715 100644 --- a/backend/pkgs/server/response_test.go +++ b/backend/pkgs/server/response_test.go @@ -1,7 +1,6 @@ package server import ( - "errors" "net/http" "net/http/httptest" "testing" @@ -38,41 +37,3 @@ func Test_Respond_JSON(t *testing.T) { assert.Equal(t, "application/json", recorder.Header().Get("Content-Type")) } - -func Test_RespondError(t *testing.T) { - recorder := httptest.NewRecorder() - var customError = errors.New("custom error") - - RespondError(recorder, http.StatusBadRequest, customError) - - assert.Equal(t, http.StatusBadRequest, recorder.Code) - assert.JSONEq(t, recorder.Body.String(), `{"details":["custom error"], "message":"Bad Request", "error":true}`) - -} -func Test_RespondInternalServerError(t *testing.T) { - recorder := httptest.NewRecorder() - - RespondServerError(recorder) - - assert.Equal(t, http.StatusInternalServerError, recorder.Code) - assert.JSONEq(t, recorder.Body.String(), `{"details":["internal server error"], "message":"Internal Server Error", "error":true}`) - -} -func Test_RespondUnauthorized(t *testing.T) { - recorder := httptest.NewRecorder() - - RespondUnauthorized(recorder) - - assert.Equal(t, http.StatusUnauthorized, recorder.Code) - assert.JSONEq(t, recorder.Body.String(), `{"details":["unauthorized"], "message":"Unauthorized", "error":true}`) - -} -func Test_RespondForbidden(t *testing.T) { - recorder := httptest.NewRecorder() - - RespondForbidden(recorder) - - assert.Equal(t, http.StatusForbidden, recorder.Code) - assert.JSONEq(t, recorder.Body.String(), `{"details":["forbidden"], "message":"Forbidden", "error":true}`) - -} diff --git a/backend/pkgs/server/server.go b/backend/pkgs/server/server.go index 1cc7cf1..921c576 100644 --- a/backend/pkgs/server/server.go +++ b/backend/pkgs/server/server.go @@ -10,6 +10,8 @@ import ( "sync" "syscall" "time" + + "github.com/go-chi/chi/v5" ) var ( @@ -22,7 +24,11 @@ type Server struct { Port string Worker Worker - wg sync.WaitGroup + wg sync.WaitGroup + mux *chi.Mux + + // mw is the global middleware chain for the server. + mw []Middleware started bool activeServer *http.Server @@ -36,6 +42,7 @@ func NewServer(opts ...Option) *Server { s := &Server{ Host: "localhost", Port: "8080", + mux: chi.NewRouter(), Worker: NewSimpleWorker(), idleTimeout: 30 * time.Second, readTimeout: 10 * time.Second, @@ -75,14 +82,14 @@ func (s *Server) Shutdown(sig string) error { } -func (s *Server) Start(router http.Handler) error { +func (s *Server) Start() error { if s.started { return ErrServerAlreadyStarted } s.activeServer = &http.Server{ Addr: s.Host + ":" + s.Port, - Handler: router, + Handler: s.mux, IdleTimeout: s.idleTimeout, ReadTimeout: s.readTimeout, WriteTimeout: s.writeTimeout, diff --git a/backend/pkgs/server/server_options.go b/backend/pkgs/server/server_options.go index 1029b1f..93b7781 100644 --- a/backend/pkgs/server/server_options.go +++ b/backend/pkgs/server/server_options.go @@ -4,6 +4,13 @@ import "time" type Option = func(s *Server) error +func WithMiddleware(mw ...Middleware) Option { + return func(s *Server) error { + s.mw = append(s.mw, mw...) + return nil + } +} + func WithWorker(w Worker) Option { return func(s *Server) error { s.Worker = w diff --git a/backend/pkgs/server/server_test.go b/backend/pkgs/server/server_test.go index 1d8b105..5c07a72 100644 --- a/backend/pkgs/server/server_test.go +++ b/backend/pkgs/server/server_test.go @@ -12,8 +12,12 @@ import ( func testServer(t *testing.T, r http.Handler) *Server { svr := NewServer(WithHost("127.0.0.1"), WithPort("19245")) + if r != nil { + svr.mux.Handle("/", r) + } + go func() { - err := svr.Start(r) + err := svr.Start() assert.NoError(t, err) }() @@ -42,7 +46,7 @@ func Test_ServerShutdown_Error(t *testing.T) { func Test_ServerStarts_Error(t *testing.T) { svr := testServer(t, nil) - err := svr.Start(nil) + err := svr.Start() assert.ErrorIs(t, err, ErrServerAlreadyStarted) err = svr.Shutdown("test")