homebox/backend/app/api/middleware.go

106 lines
2.9 KiB
Go
Raw Normal View History

2022-08-30 02:30:36 +00:00
package main
import (
"context"
"errors"
2022-08-30 02:30:36 +00:00
"net/http"
"strings"
"github.com/hay-kot/homebox/backend/internal/core/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/server"
2022-08-30 02:30:36 +00:00
)
// tokenHasKey is an unexported context-key type; because no other package can
// name it, values stored under it cannot collide with foreign context keys.
type tokenHasKey struct{ key string }

// hashedToken is the context key under which mwAuthToken stores the request's
// auth token string so that later middleware (e.g. mwRoles) can read it back.
//
// NOTE(review): despite the name, the value stored under this key is the token
// as received from the client, not a hash of it; any hashing presumably happens
// further down (e.g. in the repository layer) — confirm before relying on it.
var hashedToken = tokenHasKey{key: "hashedToken"}
// RoleMode selects how mwRoles combines its list of required roles.
type RoleMode int

const (
	// RoleModeOr admits a request whose token has at least one required role.
	RoleModeOr RoleMode = iota // 0
	// RoleModeAnd admits a request only if its token has every required role.
	RoleModeAnd // 1
)
// mwRoles is a middleware that authorizes a request based on the roles attached
// to its auth token. The required roles are combined according to rm:
//
//   - RoleModeOr:  at least one of the required roles must be present
//     (so an empty required list denies every request).
//   - RoleModeAnd: every required role must be present
//     (so an empty required list allows every request).
//
// Any other RoleMode value denies the request (fail closed). When the check
// fails a 403 Forbidden request error is returned; an error from looking up the
// token's roles is returned to the caller unchanged.
//
// WARNING: This middleware _MUST_ be called after mwAuthToken or else it will panic
func (a *app) mwRoles(rm RoleMode, required ...string) server.Middleware {
	return func(next server.Handler) server.Handler {
		return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
			ctx := r.Context()

			// mwAuthToken stores the token under hashedToken. A missing (or
			// non-string) value means the middleware chain is misconfigured,
			// which is a programmer error rather than a client error.
			token, ok := ctx.Value(hashedToken).(string)
			if !ok {
				panic("mwRoles: token not found in context, you must call mwAuthToken before mwRoles")
			}

			roles, err := a.repos.AuthTokens.GetRoles(ctx, token)
			if err != nil {
				return err
			}

			allowed := false
			switch rm {
			case RoleModeOr:
				for _, role := range required {
					if roles.Contains(role) {
						allowed = true
						break
					}
				}
			case RoleModeAnd:
				allowed = true
				for _, role := range required {
					if !roles.Contains(role) {
						allowed = false
						break
					}
				}
			}
			// Unknown modes leave allowed == false, so a bad RoleMode can never
			// silently grant access.
			if !allowed {
				return validate.NewRequestError(errors.New("Forbidden"), http.StatusForbidden)
			}

			return next.ServeHTTP(w, r)
		})
	}
}
2022-08-30 02:30:36 +00:00
// mwAuthToken is a middleware that will check the database for a stateful token
// and attach it's user to the request context, or return an appropriate error.
// Authorization support is by token via Headers or Query Parameter
//
// Example:
// - header = "Bearer 1234567890"
// - query = "?access_token=1234567890"
func (a *app) mwAuthToken(next server.Handler) server.Handler {
return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
2022-08-30 02:30:36 +00:00
requestToken := r.Header.Get("Authorization")
if requestToken == "" {
// check for query param
requestToken = r.URL.Query().Get("access_token")
if requestToken == "" {
return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized)
}
2022-08-30 02:30:36 +00:00
}
requestToken = strings.TrimPrefix(requestToken, "Bearer ")
r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken))
2022-08-31 02:11:23 +00:00
usr, err := a.services.User.GetSelf(r.Context(), requestToken)
2022-08-30 02:30:36 +00:00
// Check the database for the token
if err != nil {
return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized)
2022-08-30 02:30:36 +00:00
}
r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken))
return next.ServeHTTP(w, r)
2022-08-30 02:30:36 +00:00
})
}