refactor: http interfaces (#114)

* implement custom http handler interface

* implement trace_id

* normalize http method spacing for consistent logs

* fix failing test

* fix linter errors

* cleanup old dead code

* more route cleanup

* cleanup some inconsistent errors

* update and generate code

* make taskfile more consistent

* update task calls

* run tidy

* drop `@` tag for version

* use relative paths

* tidy

* fix auto-setting variables

* update build paths

* add contributing guide

* tidy
This commit is contained in:
Hayden 2022-10-29 18:15:35 -08:00 committed by GitHub
parent e2d93f8523
commit 6529549289
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 984 additions and 808 deletions

View file

@ -0,0 +1,23 @@
package server
import "errors"
type shutdownError struct {
message string
}
func (e *shutdownError) Error() string {
return e.message
}
// ShutdownError returns an error that indicates that the server has lost
// integrity and should be shut down.
func ShutdownError(message string) error {
return &shutdownError{message}
}
// IsShutdownError returns true if the error is a shutdown error.
func IsShutdownError(err error) bool {
var e *shutdownError
return errors.As(err, &e)
}

View file

@ -0,0 +1,25 @@
package server
import (
"net/http"
)
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
return f(w, r)
}
type Handler interface {
ServeHTTP(http.ResponseWriter, *http.Request) error
}
// ToHandler converts a function to a customer implementation of the Handler interface.
// that returns an error. This wrapper around the handler function and simply
// returns the nil in all cases
func ToHandler(handler http.Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
handler.ServeHTTP(w, r)
return nil
})
}

View file

@ -0,0 +1,38 @@
package server
import (
"net/http"
"strings"
)
type Middleware func(Handler) Handler
// wrapMiddleware creates a new handler by wrapping middleware around a final
// handler. The middlewares' Handlers will be executed by requests in the order
// they are provided.
func wrapMiddleware(mw []Middleware, handler Handler) Handler {
// Loop backwards through the middleware invoking each one. Replace the
// handler with the new wrapped handler. Looping backwards ensures that the
// first middleware of the slice is the first to be executed by requests.
for i := len(mw) - 1; i >= 0; i-- {
h := mw[i]
if h != nil {
handler = h(handler)
}
}
return handler
}
// StripTrailingSlash is a middleware that will strip trailing slashes from the request path.
//
// Example: /api/v1/ -> /api/v1
func StripTrailingSlash() Middleware {
return func(h Handler) Handler {
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
r.URL.Path = strings.TrimSuffix(r.URL.Path, "/")
return h.ServeHTTP(w, r)
})
}
}

103
backend/pkgs/server/mux.go Normal file
View file

@ -0,0 +1,103 @@
package server
import (
"context"
"net/http"
"github.com/google/uuid"
)
type vkey int
const (
// Key is the key for the server in the request context.
key vkey = 1
)
type Values struct {
TraceID string
}
func GetTraceID(ctx context.Context) string {
v, ok := ctx.Value(key).(Values)
if !ok {
return ""
}
return v.TraceID
}
func (s *Server) toHttpHandler(handler Handler, mw ...Middleware) http.HandlerFunc {
handler = wrapMiddleware(mw, handler)
handler = wrapMiddleware(s.mw, handler)
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
// Add the trace ID to the context
ctx = context.WithValue(ctx, key, Values{
TraceID: uuid.NewString(),
})
err := handler.ServeHTTP(w, r.WithContext(ctx))
if err != nil {
if IsShutdownError(err) {
_ = s.Shutdown("SIGTERM")
}
}
}
}
func (s *Server) handle(method, pattern string, handler Handler, mw ...Middleware) {
h := s.toHttpHandler(handler, mw...)
switch method {
case http.MethodGet:
s.mux.Get(pattern, h)
case http.MethodPost:
s.mux.Post(pattern, h)
case http.MethodPut:
s.mux.Put(pattern, h)
case http.MethodDelete:
s.mux.Delete(pattern, h)
case http.MethodPatch:
s.mux.Patch(pattern, h)
case http.MethodHead:
s.mux.Head(pattern, h)
case http.MethodOptions:
s.mux.Options(pattern, h)
}
}
func (s *Server) Get(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodGet, pattern, handler, mw...)
}
func (s *Server) Post(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPost, pattern, handler, mw...)
}
func (s *Server) Put(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPut, pattern, handler, mw...)
}
func (s *Server) Delete(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodDelete, pattern, handler, mw...)
}
func (s *Server) Patch(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodPatch, pattern, handler, mw...)
}
func (s *Server) Head(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodHead, pattern, handler, mw...)
}
func (s *Server) Options(pattern string, handler Handler, mw ...Middleware) {
s.handle(http.MethodOptions, pattern, handler, mw...)
}
func (s *Server) NotFound(handler Handler) {
s.mux.NotFound(s.toHttpHandler(handler))
}

View file

@ -16,7 +16,7 @@ func Decode(r *http.Request, val interface{}) error {
return nil
}
// GetId is a shotcut to get the id from the request URL or return a default value
// GetId is a shortcut to get the id from the request URL or return a default value
func GetParam(r *http.Request, key, d string) string {
val := r.URL.Query().Get(key)
@ -27,22 +27,22 @@ func GetParam(r *http.Request, key, d string) string {
return val
}
// GetSkip is a shotcut to get the skip from the request URL parameters
// GetSkip is a shortcut to get the skip from the request URL parameters
func GetSkip(r *http.Request, d string) string {
return GetParam(r, "skip", d)
}
// GetSkip is a shotcut to get the skip from the request URL parameters
// GetSkip is a shortcut to get the skip from the request URL parameters
func GetId(r *http.Request, d string) string {
return GetParam(r, "id", d)
}
// GetLimit is a shotcut to get the limit from the request URL parameters
// GetLimit is a shortcut to get the limit from the request URL parameters
func GetLimit(r *http.Request, d string) string {
return GetParam(r, "limit", d)
}
// GetQuery is a shotcut to get the sort from the request URL parameters
// GetQuery is a shortcut to get the sort from the request URL parameters
func GetQuery(r *http.Request, d string) string {
return GetParam(r, "query", d)
}

View file

@ -2,16 +2,20 @@ package server
import (
"encoding/json"
"errors"
"net/http"
)
type ErrorResponse struct {
Error string `json:"error"`
Fields map[string]string `json:"fields,omitempty"`
}
// Respond converts a Go value to JSON and sends it to the client.
// Adapted from https://github.com/ardanlabs/service/tree/master/foundation/web
func Respond(w http.ResponseWriter, statusCode int, data interface{}) {
func Respond(w http.ResponseWriter, statusCode int, data interface{}) error {
if statusCode == http.StatusNoContent {
w.WriteHeader(statusCode)
return
return nil
}
// Convert the response value to JSON.
@ -28,31 +32,8 @@ func Respond(w http.ResponseWriter, statusCode int, data interface{}) {
// Send the result back to the client.
if _, err := w.Write(jsonData); err != nil {
panic(err)
return err
}
}
// ResponseError is a helper function that sends a JSON response of an error message
func RespondError(w http.ResponseWriter, statusCode int, err error) {
eb := ErrorBuilder{}
eb.AddError(err)
eb.Respond(w, statusCode)
}
// RespondServerError is a wrapper around RespondError that sends a 500 internal server error. Useful for
// Sending generic errors when everything went wrong.
func RespondServerError(w http.ResponseWriter) {
RespondError(w, http.StatusInternalServerError, errors.New("internal server error"))
}
// RespondNotFound is a helper utility for responding with a generic
// "unauthorized" error.
func RespondUnauthorized(w http.ResponseWriter) {
RespondError(w, http.StatusUnauthorized, errors.New("unauthorized"))
}
// RespondForbidden is a helper utility for responding with a generic
// "forbidden" error.
func RespondForbidden(w http.ResponseWriter) {
RespondError(w, http.StatusForbidden, errors.New("forbidden"))
return nil
}

View file

@ -1,76 +0,0 @@
package server
import (
"net/http"
)
type ValidationError struct {
Field string `json:"field"`
Reason string `json:"reason"`
}
type ValidationErrors []ValidationError
func (ve *ValidationErrors) HasErrors() bool {
if (ve == nil) || (len(*ve) == 0) {
return false
}
for _, err := range *ve {
if err.Field != "" {
return true
}
}
return false
}
func (ve ValidationErrors) Append(field, reasons string) ValidationErrors {
return append(ve, ValidationError{
Field: field,
Reason: reasons,
})
}
// ErrorBuilder is a helper type to build a response that contains an array of errors.
// Typical use cases are for returning an array of validation errors back to the user.
//
// Example:
//
// {
// "errors": [
// "invalid id",
// "invalid name",
// "invalid description"
// ],
// "message": "Unprocessable Entity",
// "status": 422
// }
type ErrorBuilder struct {
errs []string
}
// HasErrors returns true if the ErrorBuilder has any errors.
func (eb *ErrorBuilder) HasErrors() bool {
if (eb.errs == nil) || (len(eb.errs) == 0) {
return false
}
return true
}
// AddError adds an error to the ErrorBuilder if an error is not nil. If the
// Error is nil, then nothing is added.
func (eb *ErrorBuilder) AddError(err error) {
if err != nil {
if eb.errs == nil {
eb.errs = make([]string, 0)
}
eb.errs = append(eb.errs, err.Error())
}
}
// Respond sends a JSON response with the ErrorBuilder's errors. If there are no errors, then
// the errors field will be an empty array.
func (eb *ErrorBuilder) Respond(w http.ResponseWriter, statusCode int) {
Respond(w, statusCode, Wrap(nil).AddError(http.StatusText(statusCode), eb.errs))
}

View file

@ -1,107 +0,0 @@
package server
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/hay-kot/homebox/backend/pkgs/faker"
"github.com/stretchr/testify/assert"
)
func Test_ErrorBuilder_HasErrors_NilList(t *testing.T) {
t.Parallel()
var ebNilList = ErrorBuilder{}
assert.False(t, ebNilList.HasErrors(), "ErrorBuilder.HasErrors() should return false when list is nil")
}
func Test_ErrorBuilder_HasErrors_EmptyList(t *testing.T) {
t.Parallel()
var ebEmptyList = ErrorBuilder{
errs: []string{},
}
assert.False(t, ebEmptyList.HasErrors(), "ErrorBuilder.HasErrors() should return false when list is empty")
}
func Test_ErrorBuilder_HasErrors_WithError(t *testing.T) {
t.Parallel()
var ebList = ErrorBuilder{}
ebList.AddError(errors.New("test error"))
assert.True(t, ebList.HasErrors(), "ErrorBuilder.HasErrors() should return true when list is not empty")
}
func Test_ErrorBuilder_AddError(t *testing.T) {
t.Parallel()
randomError := make([]error, 10)
f := faker.NewFaker()
errorStrings := make([]string, 10)
for i := 0; i < 10; i++ {
err := errors.New(f.Str(10))
randomError[i] = err
errorStrings[i] = err.Error()
}
// Check Results
var ebList = ErrorBuilder{}
for _, err := range randomError {
ebList.AddError(err)
}
assert.Equal(t, errorStrings, ebList.errs, "ErrorBuilder.AddError() should add an error to the list")
}
func Test_ErrorBuilder_Respond(t *testing.T) {
t.Parallel()
f := faker.NewFaker()
randomError := make([]error, 5)
for i := 0; i < 5; i++ {
err := errors.New(f.Str(5))
randomError[i] = err
}
// Check Results
var ebList = ErrorBuilder{}
for _, err := range randomError {
ebList.AddError(err)
}
fakeWriter := httptest.NewRecorder()
ebList.Respond(fakeWriter, 422)
assert.Equal(t, 422, fakeWriter.Code, "ErrorBuilder.Respond() should return a status code of 422")
// Check errors payload is correct
errorsStruct := struct {
Errors []string `json:"details"`
Message string `json:"message"`
Error bool `json:"error"`
}{
Errors: ebList.errs,
Message: http.StatusText(http.StatusUnprocessableEntity),
Error: true,
}
asJson, _ := json.Marshal(errorsStruct)
assert.JSONEq(t, string(asJson), fakeWriter.Body.String(), "ErrorBuilder.Respond() should return a JSON response with the errors")
}

View file

@ -1,7 +1,6 @@
package server
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
@ -17,7 +16,8 @@ func Test_Respond_NoContent(t *testing.T) {
Name: "dummy",
}
Respond(recorder, http.StatusNoContent, dummystruct)
err := Respond(recorder, http.StatusNoContent, dummystruct)
assert.NoError(t, err)
assert.Equal(t, http.StatusNoContent, recorder.Code)
assert.Empty(t, recorder.Body.String())
@ -31,48 +31,11 @@ func Test_Respond_JSON(t *testing.T) {
Name: "dummy",
}
Respond(recorder, http.StatusCreated, dummystruct)
err := Respond(recorder, http.StatusCreated, dummystruct)
assert.NoError(t, err)
assert.Equal(t, http.StatusCreated, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"name":"dummy"}`)
assert.Equal(t, "application/json", recorder.Header().Get("Content-Type"))
}
func Test_RespondError(t *testing.T) {
recorder := httptest.NewRecorder()
var customError = errors.New("custom error")
RespondError(recorder, http.StatusBadRequest, customError)
assert.Equal(t, http.StatusBadRequest, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["custom error"], "message":"Bad Request", "error":true}`)
}
func Test_RespondInternalServerError(t *testing.T) {
recorder := httptest.NewRecorder()
RespondServerError(recorder)
assert.Equal(t, http.StatusInternalServerError, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["internal server error"], "message":"Internal Server Error", "error":true}`)
}
func Test_RespondUnauthorized(t *testing.T) {
recorder := httptest.NewRecorder()
RespondUnauthorized(recorder)
assert.Equal(t, http.StatusUnauthorized, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["unauthorized"], "message":"Unauthorized", "error":true}`)
}
func Test_RespondForbidden(t *testing.T) {
recorder := httptest.NewRecorder()
RespondForbidden(recorder)
assert.Equal(t, http.StatusForbidden, recorder.Code)
assert.JSONEq(t, recorder.Body.String(), `{"details":["forbidden"], "message":"Forbidden", "error":true}`)
}

View file

@ -17,15 +17,3 @@ func Wrap(data interface{}) Result {
Item: data,
}
}
func (r Result) AddMessage(message string) Result {
r.Message = message
return r
}
func (r Result) AddError(err string, details interface{}) Result {
r.Message = err
r.Details = details
r.Error = true
return r
}

View file

@ -10,6 +10,8 @@ import (
"sync"
"syscall"
"time"
"github.com/go-chi/chi/v5"
)
var (
@ -22,7 +24,11 @@ type Server struct {
Port string
Worker Worker
wg sync.WaitGroup
wg sync.WaitGroup
mux *chi.Mux
// mw is the global middleware chain for the server.
mw []Middleware
started bool
activeServer *http.Server
@ -36,6 +42,7 @@ func NewServer(opts ...Option) *Server {
s := &Server{
Host: "localhost",
Port: "8080",
mux: chi.NewRouter(),
Worker: NewSimpleWorker(),
idleTimeout: 30 * time.Second,
readTimeout: 10 * time.Second,
@ -75,14 +82,14 @@ func (s *Server) Shutdown(sig string) error {
}
func (s *Server) Start(router http.Handler) error {
func (s *Server) Start() error {
if s.started {
return ErrServerAlreadyStarted
}
s.activeServer = &http.Server{
Addr: s.Host + ":" + s.Port,
Handler: router,
Handler: s.mux,
IdleTimeout: s.idleTimeout,
ReadTimeout: s.readTimeout,
WriteTimeout: s.writeTimeout,

View file

@ -4,6 +4,13 @@ import "time"
type Option = func(s *Server) error
func WithMiddleware(mw ...Middleware) Option {
return func(s *Server) error {
s.mw = append(s.mw, mw...)
return nil
}
}
func WithWorker(w Worker) Option {
return func(s *Server) error {
s.Worker = w

View file

@ -12,8 +12,11 @@ import (
func testServer(t *testing.T, r http.Handler) *Server {
svr := NewServer(WithHost("127.0.0.1"), WithPort("19245"))
if r != nil {
svr.mux.Mount("/", r)
}
go func() {
err := svr.Start(r)
err := svr.Start()
assert.NoError(t, err)
}()
@ -42,7 +45,7 @@ func Test_ServerShutdown_Error(t *testing.T) {
func Test_ServerStarts_Error(t *testing.T) {
svr := testServer(t, nil)
err := svr.Start(nil)
err := svr.Start()
assert.ErrorIs(t, err, ErrServerAlreadyStarted)
err = svr.Shutdown("test")