forked from mirrors/homebox
chore: refactor api endpoints (#339)
* move typegen code * update taskfile to fix code-gen caches and use 'dir' attribute * enable dumping stack traces for errors * log request start and stop * set zerolog stack handler * fix routes function * refactor context adapters to use requests directly * change some method signatures to support GID * start requiring validation tags * first pass on updating handlers to use adapters * add errs package * code gen * tidy * rework API to use external server package
This commit is contained in:
parent
184b494fc3
commit
db80f8a159
56 changed files with 806 additions and 1947 deletions
|
@ -1,8 +0,0 @@
|
|||
package server
|
||||
|
||||
const (
|
||||
ContentType = "Content-Type"
|
||||
ContentJSON = "application/json"
|
||||
ContentXML = "application/xml"
|
||||
ContentFormUrlEncoded = "application/x-www-form-urlencoded"
|
||||
)
|
|
@ -1,23 +0,0 @@
|
|||
package server
|
||||
|
||||
import "errors"
|
||||
|
||||
type shutdownError struct {
|
||||
message string
|
||||
}
|
||||
|
||||
func (e *shutdownError) Error() string {
|
||||
return e.message
|
||||
}
|
||||
|
||||
// ShutdownError returns an error that indicates that the server has lost
|
||||
// integrity and should be shut down.
|
||||
func ShutdownError(message string) error {
|
||||
return &shutdownError{message}
|
||||
}
|
||||
|
||||
// IsShutdownError returns true if the error is a shutdown error.
|
||||
func IsShutdownError(err error) bool {
|
||||
var e *shutdownError
|
||||
return errors.As(err, &e)
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error
|
||||
|
||||
func (f HandlerFunc) ServeHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
return f(w, r)
|
||||
}
|
||||
|
||||
type Handler interface {
|
||||
ServeHTTP(http.ResponseWriter, *http.Request) error
|
||||
}
|
||||
|
||||
// ToHandler converts a function to a customer implementation of the Handler interface.
|
||||
// that returns an error. This wrapper around the handler function and simply
|
||||
// returns the nil in all cases
|
||||
func ToHandler(handler http.Handler) Handler {
|
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
handler.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Middleware func(Handler) Handler
|
||||
|
||||
// wrapMiddleware creates a new handler by wrapping middleware around a final
|
||||
// handler. The middlewares' Handlers will be executed by requests in the order
|
||||
// they are provided.
|
||||
func wrapMiddleware(mw []Middleware, handler Handler) Handler {
|
||||
// Loop backwards through the middleware invoking each one. Replace the
|
||||
// handler with the new wrapped handler. Looping backwards ensures that the
|
||||
// first middleware of the slice is the first to be executed by requests.
|
||||
for i := len(mw) - 1; i >= 0; i-- {
|
||||
h := mw[i]
|
||||
if h != nil {
|
||||
handler = h(handler)
|
||||
}
|
||||
}
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
// StripTrailingSlash is a middleware that will strip trailing slashes from the request path.
|
||||
//
|
||||
// Example: /api/v1/ -> /api/v1
|
||||
func StripTrailingSlash() Middleware {
|
||||
return func(h Handler) Handler {
|
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
r.URL.Path = strings.TrimSuffix(r.URL.Path, "/")
|
||||
return h.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,102 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type vkey int
|
||||
|
||||
const (
|
||||
// Key is the key for the server in the request context.
|
||||
key vkey = 1
|
||||
)
|
||||
|
||||
type Values struct {
|
||||
TraceID string
|
||||
}
|
||||
|
||||
func GetTraceID(ctx context.Context) string {
|
||||
v, ok := ctx.Value(key).(Values)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return v.TraceID
|
||||
}
|
||||
|
||||
func (s *Server) toHttpHandler(handler Handler, mw ...Middleware) http.HandlerFunc {
|
||||
handler = wrapMiddleware(mw, handler)
|
||||
|
||||
handler = wrapMiddleware(s.mw, handler)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
// Add the trace ID to the context
|
||||
ctx = context.WithValue(ctx, key, Values{
|
||||
TraceID: uuid.NewString(),
|
||||
})
|
||||
|
||||
err := handler.ServeHTTP(w, r.WithContext(ctx))
|
||||
if err != nil {
|
||||
if IsShutdownError(err) {
|
||||
_ = s.Shutdown("SIGTERM")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handle(method, pattern string, handler Handler, mw ...Middleware) {
|
||||
h := s.toHttpHandler(handler, mw...)
|
||||
|
||||
switch method {
|
||||
case http.MethodGet:
|
||||
s.mux.Get(pattern, h)
|
||||
case http.MethodPost:
|
||||
s.mux.Post(pattern, h)
|
||||
case http.MethodPut:
|
||||
s.mux.Put(pattern, h)
|
||||
case http.MethodDelete:
|
||||
s.mux.Delete(pattern, h)
|
||||
case http.MethodPatch:
|
||||
s.mux.Patch(pattern, h)
|
||||
case http.MethodHead:
|
||||
s.mux.Head(pattern, h)
|
||||
case http.MethodOptions:
|
||||
s.mux.Options(pattern, h)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Get(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodGet, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Post(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodPost, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Put(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodPut, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Delete(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodDelete, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Patch(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodPatch, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Head(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodHead, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) Options(pattern string, handler Handler, mw ...Middleware) {
|
||||
s.handle(http.MethodOptions, pattern, handler, mw...)
|
||||
}
|
||||
|
||||
func (s *Server) NotFound(handler Handler) {
|
||||
s.mux.NotFound(s.toHttpHandler(handler))
|
||||
}
|
|
@ -1,48 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// Decode reads the body of an HTTP request looking for a JSON document. The
|
||||
// body is decoded into the provided value.
|
||||
func Decode(r *http.Request, val interface{}) error {
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
// decoder.DisallowUnknownFields()
|
||||
if err := decoder.Decode(val); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetId is a shortcut to get the id from the request URL or return a default value
|
||||
func GetParam(r *http.Request, key, d string) string {
|
||||
val := r.URL.Query().Get(key)
|
||||
|
||||
if val == "" {
|
||||
return d
|
||||
}
|
||||
|
||||
return val
|
||||
}
|
||||
|
||||
// GetSkip is a shortcut to get the skip from the request URL parameters
|
||||
func GetSkip(r *http.Request, d string) string {
|
||||
return GetParam(r, "skip", d)
|
||||
}
|
||||
|
||||
// GetSkip is a shortcut to get the skip from the request URL parameters
|
||||
func GetId(r *http.Request, d string) string {
|
||||
return GetParam(r, "id", d)
|
||||
}
|
||||
|
||||
// GetLimit is a shortcut to get the limit from the request URL parameters
|
||||
func GetLimit(r *http.Request, d string) string {
|
||||
return GetParam(r, "limit", d)
|
||||
}
|
||||
|
||||
// GetQuery is a shortcut to get the sort from the request URL parameters
|
||||
func GetQuery(r *http.Request, d string) string {
|
||||
return GetParam(r, "query", d)
|
||||
}
|
|
@ -1,210 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type TestStruct struct {
|
||||
Name string `json:"name"`
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func TestDecode(t *testing.T) {
|
||||
type args struct {
|
||||
r *http.Request
|
||||
val interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "check_error",
|
||||
args: args{
|
||||
r: &http.Request{
|
||||
Body: http.NoBody,
|
||||
},
|
||||
val: make(map[string]interface{}),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "check_success",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
val: TestStruct{
|
||||
Name: "test",
|
||||
Data: "test",
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := Decode(tt.args.r, &tt.args.val); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Decode() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetParam(t *testing.T) {
|
||||
type args struct {
|
||||
r *http.Request
|
||||
key string
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "check_default",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
key: "id",
|
||||
d: "default",
|
||||
},
|
||||
want: "default",
|
||||
},
|
||||
{
|
||||
name: "check_id",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/item?id=123", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
key: "id",
|
||||
d: "",
|
||||
},
|
||||
want: "123",
|
||||
},
|
||||
{
|
||||
name: "check_query",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/item?query=hello-world", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
key: "query",
|
||||
d: "",
|
||||
},
|
||||
want: "hello-world",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetParam(tt.args.r, tt.args.key, tt.args.d); got != tt.want {
|
||||
t.Errorf("GetParam() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSkip(t *testing.T) {
|
||||
type args struct {
|
||||
r *http.Request
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "check_default",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "0",
|
||||
},
|
||||
{
|
||||
name: "check_skip",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/item?skip=107", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "107",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetSkip(tt.args.r, tt.args.d); got != tt.want {
|
||||
t.Errorf("GetSkip() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetLimit(t *testing.T) {
|
||||
type args struct {
|
||||
r *http.Request
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "check_default",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "0",
|
||||
},
|
||||
{
|
||||
name: "check_limit",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/item?limit=107", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "107",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetLimit(tt.args.r, tt.args.d); got != tt.want {
|
||||
t.Errorf("GetLimit() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetQuery(t *testing.T) {
|
||||
type args struct {
|
||||
r *http.Request
|
||||
d string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "check_default",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "0",
|
||||
},
|
||||
{
|
||||
name: "check_query",
|
||||
args: args{
|
||||
r: httptest.NewRequest("POST", "/item?query=hello-query", strings.NewReader(`{"name":"test","data":"test"}`)),
|
||||
d: "0",
|
||||
},
|
||||
want: "hello-query",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := GetQuery(tt.args.r, tt.args.d); got != tt.want {
|
||||
t.Errorf("GetQuery() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
|
@ -1,39 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type ErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
Fields map[string]string `json:"fields,omitempty"`
|
||||
}
|
||||
|
||||
// Respond converts a Go value to JSON and sends it to the client.
|
||||
// Adapted from https://github.com/ardanlabs/service/tree/master/foundation/web
|
||||
func Respond(w http.ResponseWriter, statusCode int, data interface{}) error {
|
||||
if statusCode == http.StatusNoContent {
|
||||
w.WriteHeader(statusCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert the response value to JSON.
|
||||
jsonData, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Set the content type and headers once we know marshaling has succeeded.
|
||||
w.Header().Set("Content-Type", ContentJSON)
|
||||
|
||||
// Write the status code to the response.
|
||||
w.WriteHeader(statusCode)
|
||||
|
||||
// Send the result back to the client.
|
||||
if _, err := w.Write(jsonData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,40 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func Test_Respond_NoContent(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
dummystruct := struct {
|
||||
Name string
|
||||
}{
|
||||
Name: "dummy",
|
||||
}
|
||||
|
||||
err := Respond(recorder, http.StatusNoContent, dummystruct)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, recorder.Code)
|
||||
assert.Empty(t, recorder.Body.String())
|
||||
}
|
||||
|
||||
func Test_Respond_JSON(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
dummystruct := struct {
|
||||
Name string `json:"name"`
|
||||
}{
|
||||
Name: "dummy",
|
||||
}
|
||||
|
||||
err := Respond(recorder, http.StatusCreated, dummystruct)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, recorder.Code)
|
||||
assert.JSONEq(t, recorder.Body.String(), `{"name":"dummy"}`)
|
||||
assert.Equal(t, "application/json", recorder.Header().Get("Content-Type"))
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
package server
|
||||
|
||||
type Result struct {
|
||||
Error bool `json:"error,omitempty"`
|
||||
Details interface{} `json:"details,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Item interface{} `json:"item,omitempty"`
|
||||
}
|
||||
|
||||
type Results struct {
|
||||
Items any `json:"items"`
|
||||
}
|
||||
|
||||
// Wrap creates a Wrapper instance and adds the initial namespace and data to be returned.
|
||||
func Wrap(data interface{}) Result {
|
||||
return Result{
|
||||
Item: data,
|
||||
}
|
||||
}
|
|
@ -1,144 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrServerNotStarted = errors.New("server not started")
|
||||
ErrServerAlreadyStarted = errors.New("server already started")
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Host string
|
||||
Port string
|
||||
Worker Worker
|
||||
|
||||
wg sync.WaitGroup
|
||||
mux *chi.Mux
|
||||
|
||||
// mw is the global middleware chain for the server.
|
||||
mw []Middleware
|
||||
|
||||
started bool
|
||||
activeServer *http.Server
|
||||
|
||||
idleTimeout time.Duration
|
||||
readTimeout time.Duration
|
||||
writeTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewServer(opts ...Option) *Server {
|
||||
s := &Server{
|
||||
Host: "localhost",
|
||||
Port: "8080",
|
||||
mux: chi.NewRouter(),
|
||||
Worker: NewSimpleWorker(),
|
||||
idleTimeout: 30 * time.Second,
|
||||
readTimeout: 10 * time.Second,
|
||||
writeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
err := opt(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *Server) Shutdown(sig string) error {
|
||||
if !s.started {
|
||||
return ErrServerNotStarted
|
||||
}
|
||||
fmt.Printf("Received %s signal, shutting down\n", sig)
|
||||
|
||||
// Create a context with a 5-second timeout.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := s.activeServer.Shutdown(ctx)
|
||||
s.started = false
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("Http server shutdown, waiting for all tasks to finish")
|
||||
s.wg.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
if s.started {
|
||||
return ErrServerAlreadyStarted
|
||||
}
|
||||
|
||||
s.activeServer = &http.Server{
|
||||
Addr: s.Host + ":" + s.Port,
|
||||
Handler: s.mux,
|
||||
IdleTimeout: s.idleTimeout,
|
||||
ReadTimeout: s.readTimeout,
|
||||
WriteTimeout: s.writeTimeout,
|
||||
}
|
||||
|
||||
shutdownError := make(chan error)
|
||||
|
||||
go func() {
|
||||
// Create a quit channel which carries os.Signal values.
|
||||
quit := make(chan os.Signal, 1)
|
||||
|
||||
// Use signal.Notify() to listen for incoming SIGINT and SIGTERM signals and
|
||||
// relay them to the quit channel.
|
||||
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Read the signal from the quit channel. block until received
|
||||
sig := <-quit
|
||||
|
||||
err := s.Shutdown(sig.String())
|
||||
if err != nil {
|
||||
shutdownError <- err
|
||||
}
|
||||
|
||||
// Exit the application with a 0 (success) status code.
|
||||
os.Exit(0)
|
||||
}()
|
||||
|
||||
s.started = true
|
||||
err := s.activeServer.ListenAndServe()
|
||||
|
||||
if !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
|
||||
err = <-shutdownError
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Println("Server shutdown successfully")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Background starts a go routine that runs on the servers pool. In the event of a shutdown
|
||||
// request, the server will wait until all open goroutines have finished before shutting down.
|
||||
func (svr *Server) Background(task func()) {
|
||||
svr.wg.Add(1)
|
||||
svr.Worker.Add(func() {
|
||||
defer svr.wg.Done()
|
||||
task()
|
||||
})
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
package server
|
||||
|
||||
import "time"
|
||||
|
||||
type Option = func(s *Server) error
|
||||
|
||||
func WithMiddleware(mw ...Middleware) Option {
|
||||
return func(s *Server) error {
|
||||
s.mw = append(s.mw, mw...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithWorker(w Worker) Option {
|
||||
return func(s *Server) error {
|
||||
s.Worker = w
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithHost(host string) Option {
|
||||
return func(s *Server) error {
|
||||
s.Host = host
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithPort(port string) Option {
|
||||
return func(s *Server) error {
|
||||
s.Port = port
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithReadTimeout(seconds int) Option {
|
||||
return func(s *Server) error {
|
||||
s.readTimeout = time.Duration(seconds) * time.Second
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithWriteTimeout(seconds int) Option {
|
||||
return func(s *Server) error {
|
||||
s.writeTimeout = time.Duration(seconds) * time.Second
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithIdleTimeout(seconds int) Option {
|
||||
return func(s *Server) error {
|
||||
s.idleTimeout = time.Duration(seconds) * time.Second
|
||||
return nil
|
||||
}
|
||||
}
|
|
@ -1,101 +0,0 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func testServer(t *testing.T, r http.Handler) *Server {
|
||||
svr := NewServer(WithHost("127.0.0.1"), WithPort("19245"))
|
||||
|
||||
if r != nil {
|
||||
svr.mux.Mount("/", r)
|
||||
}
|
||||
go func() {
|
||||
err := svr.Start()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ping := func() error {
|
||||
_, err := http.Get("http://127.0.0.1:19245")
|
||||
return err
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ping(); err == nil {
|
||||
break
|
||||
}
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
}
|
||||
|
||||
return svr
|
||||
}
|
||||
|
||||
func Test_ServerShutdown_Error(t *testing.T) {
|
||||
svr := NewServer(WithHost("127.0.0.1"), WithPort("19245"))
|
||||
|
||||
err := svr.Shutdown("test")
|
||||
assert.ErrorIs(t, err, ErrServerNotStarted)
|
||||
}
|
||||
|
||||
func Test_ServerStarts_Error(t *testing.T) {
|
||||
svr := testServer(t, nil)
|
||||
|
||||
err := svr.Start()
|
||||
assert.ErrorIs(t, err, ErrServerAlreadyStarted)
|
||||
|
||||
err = svr.Shutdown("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_ServerStarts(t *testing.T) {
|
||||
svr := testServer(t, nil)
|
||||
err := svr.Shutdown("test")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_GracefulServerShutdownWithWorkers(t *testing.T) {
|
||||
isFinished := false
|
||||
|
||||
svr := testServer(t, nil)
|
||||
|
||||
svr.Background(func() {
|
||||
time.Sleep(time.Second * 4)
|
||||
isFinished = true
|
||||
})
|
||||
|
||||
err := svr.Shutdown("test")
|
||||
|
||||
assert.NoError(t, err)
|
||||
assert.True(t, isFinished)
|
||||
}
|
||||
|
||||
func Test_GracefulServerShutdownWithRequests(t *testing.T) {
|
||||
var isFinished atomic.Bool
|
||||
|
||||
router := http.NewServeMux()
|
||||
|
||||
// add long running handler func
|
||||
router.HandleFunc("/test", func(rw http.ResponseWriter, r *http.Request) {
|
||||
time.Sleep(time.Second * 3)
|
||||
isFinished.Store(true)
|
||||
})
|
||||
|
||||
svr := testServer(t, router)
|
||||
|
||||
// Make request to "/test"
|
||||
go func() {
|
||||
_, _ = http.Get("http://127.0.0.1:19245/test") // This is probably bad?
|
||||
}()
|
||||
|
||||
time.Sleep(time.Second) // Hack to wait for the request to be made
|
||||
|
||||
err := svr.Shutdown("test")
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.True(t, isFinished.Load())
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
package server
|
||||
|
||||
// TODO: #2 Implement Go routine pool/job queue
|
||||
|
||||
type Worker interface {
|
||||
Add(func())
|
||||
}
|
||||
|
||||
// SimpleWorker is a simple background worker that implements
|
||||
// the Worker interface and runs all tasks in a go routine without
|
||||
// a pool or que or limits. It's useful for simple or small applications
|
||||
// with minimal/short background tasks
|
||||
type SimpleWorker struct{}
|
||||
|
||||
func NewSimpleWorker() *SimpleWorker {
|
||||
return &SimpleWorker{}
|
||||
}
|
||||
|
||||
func (sw *SimpleWorker) Add(task func()) {
|
||||
go task()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue