feat: extract auth into provider (#663)

* extract auth into provider

* bump go version

* use pointer

* rebase
This commit is contained in:
Hayden 2023-12-12 08:49:46 -06:00 committed by GitHub
parent 522943687e
commit 8538877f52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 155 additions and 46 deletions

View file

@ -58,6 +58,25 @@ func GetCookies(r *http.Request) (*CookieContents, error) {
}, nil
}
// AuthProvider is an interface that can be implemented by any authentication provider.
// to extend authentication methods for the API.
type AuthProvider interface {
// Name returns the name of the authentication provider. This should be a unique name.
// that is URL friendly.
//
// Example: "local", "ldap"
Name() string
// Authenticate is called when a user attempts to login to the API. The implementation
// should return an error if the user cannot be authenticated. If an error is returned
// the API controller will return a vague error message to the user.
//
// Authenticate should do the following:
//
// 1. Ensure that the user exists within the database (either create, or get)
// 2. On successful authentication, they must set the user cookies.
Authenticate(w http.ResponseWriter, r *http.Request) (services.UserAuthTokenDetail, error)
}
// HandleAuthLogin godoc
//
// @Summary User Login
@ -66,53 +85,42 @@ func GetCookies(r *http.Request) (*CookieContents, error) {
// @Accept application/json
// @Param username formData string false "string" example(admin@admin.com)
// @Param password formData string false "string" example(admin)
// @Param payload body LoginForm true "Login Data"
// @Param payload body LoginForm true "Login Data"
// @Param provider query string false "auth provider"
// @Produce json
// @Success 200 {object} TokenResponse
// @Router /v1/users/login [POST]
func (ctrl *V1Controller) HandleAuthLogin() errchain.HandlerFunc {
func (ctrl *V1Controller) HandleAuthLogin(ps ...AuthProvider) errchain.HandlerFunc {
if len(ps) == 0 {
panic("no auth providers provided")
}
providers := make(map[string]AuthProvider)
for _, p := range ps {
log.Info().Str("name", p.Name()).Msg("registering auth provider")
providers[p.Name()] = p
}
return func(w http.ResponseWriter, r *http.Request) error {
loginForm := &LoginForm{}
switch r.Header.Get("Content-Type") {
case "application/x-www-form-urlencoded":
err := r.ParseForm()
if err != nil {
return errors.New("failed to parse form")
}
loginForm.Username = r.PostFormValue("username")
loginForm.Password = r.PostFormValue("password")
loginForm.StayLoggedIn = r.PostFormValue("stayLoggedIn") == "true"
case "application/json":
err := server.Decode(r, loginForm)
if err != nil {
log.Err(err).Msg("failed to decode login form")
return errors.New("failed to decode login form")
}
default:
return server.JSON(w, http.StatusBadRequest, errors.New("invalid content type"))
// Extract provider query
provider := r.URL.Query().Get("provider")
if provider == "" {
provider = "local"
}
if loginForm.Username == "" || loginForm.Password == "" {
return validate.NewFieldErrors(
validate.FieldError{
Field: "username",
Error: "username or password is empty",
},
validate.FieldError{
Field: "password",
Error: "username or password is empty",
},
)
// Get the provider
p, ok := providers[provider]
if !ok {
return validate.NewRequestError(errors.New("invalid auth provider"), http.StatusBadRequest)
}
newToken, err := ctrl.svc.User.Login(r.Context(), strings.ToLower(loginForm.Username), loginForm.Password, loginForm.StayLoggedIn)
newToken, err := p.Authenticate(w, r)
if err != nil {
return validate.NewRequestError(errors.New("authentication failed"), http.StatusInternalServerError)
log.Err(err).Msg("failed to authenticate")
return server.JSON(w, http.StatusInternalServerError, err.Error())
}
ctrl.setCookies(w, noPort(r.Host), newToken.Raw, newToken.ExpiresAt, loginForm.StayLoggedIn)
ctrl.setCookies(w, noPort(r.Host), newToken.Raw, newToken.ExpiresAt, true)
return server.JSON(w, http.StatusOK, TokenResponse{
Token: "Bearer " + newToken.Raw,
ExpiresAt: newToken.ExpiresAt,

View file

@ -0,0 +1,55 @@
package providers
import (
"errors"
"net/http"
"github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/httpkit/server"
"github.com/rs/zerolog/log"
)
type LoginForm struct {
Username string `json:"username"`
Password string `json:"password"`
StayLoggedIn bool `json:"stayLoggedIn"`
}
func getLoginForm(r *http.Request) (LoginForm, error) {
loginForm := LoginForm{}
switch r.Header.Get("Content-Type") {
case "application/x-www-form-urlencoded":
err := r.ParseForm()
if err != nil {
return loginForm, errors.New("failed to parse form")
}
loginForm.Username = r.PostFormValue("username")
loginForm.Password = r.PostFormValue("password")
loginForm.StayLoggedIn = r.PostFormValue("stayLoggedIn") == "true"
case "application/json":
err := server.Decode(r, &loginForm)
if err != nil {
log.Err(err).Msg("failed to decode login form")
return loginForm, errors.New("failed to decode login form")
}
default:
return loginForm, errors.New("invalid content type")
}
if loginForm.Username == "" || loginForm.Password == "" {
return loginForm, validate.NewFieldErrors(
validate.FieldError{
Field: "username",
Error: "username or password is empty",
},
validate.FieldError{
Field: "password",
Error: "username or password is empty",
},
)
}
return loginForm, nil
}

View file

@ -0,0 +1,30 @@
package providers
import (
"net/http"
"github.com/hay-kot/homebox/backend/internal/core/services"
)
type LocalProvider struct {
service *services.UserService
}
func NewLocalProvider(service *services.UserService) *LocalProvider {
return &LocalProvider{
service: service,
}
}
func (p *LocalProvider) Name() string {
return "local"
}
func (p *LocalProvider) Authenticate(w http.ResponseWriter, r *http.Request) (services.UserAuthTokenDetail, error) {
loginForm, err := getLoginForm(r)
if err != nil {
return services.UserAuthTokenDetail{}, err
}
return p.service.Login(r.Context(), loginForm.Username, loginForm.Password, loginForm.StayLoggedIn)
}

View file

@ -12,6 +12,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/hay-kot/homebox/backend/app/api/handlers/debughandlers"
v1 "github.com/hay-kot/homebox/backend/app/api/handlers/v1"
"github.com/hay-kot/homebox/backend/app/api/providers"
_ "github.com/hay-kot/homebox/backend/app/api/static/docs"
"github.com/hay-kot/homebox/backend/internal/data/ent/authroles"
"github.com/hay-kot/homebox/backend/internal/data/repo"
@ -63,8 +64,12 @@ func (a *app) mountRoutes(r *chi.Mux, chain *errchain.ErrChain, repos *repo.AllR
BuildTime: buildTime,
})))
providers := []v1.AuthProvider{
providers.NewLocalProvider(a.services.User),
}
r.Post(v1Base("/users/register"), chain.ToHandlerFunc(v1Ctrl.HandleUserRegistration()))
r.Post(v1Base("/users/login"), chain.ToHandlerFunc(v1Ctrl.HandleAuthLogin()))
r.Post(v1Base("/users/login"), chain.ToHandlerFunc(v1Ctrl.HandleAuthLogin(providers...)))
userMW := []errchain.Middleware{
a.mwAuthToken,