diff --git a/backend/app/api/middleware.go b/backend/app/api/middleware.go index b41c913..c694618 100644 --- a/backend/app/api/middleware.go +++ b/backend/app/api/middleware.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "net/url" "strings" "github.com/hay-kot/homebox/backend/internal/core/services" @@ -68,6 +69,45 @@ func (a *app) mwRoles(rm RoleMode, required ...string) server.Middleware { } } +type KeyFunc func(r *http.Request) (string, error) + +func getBearer(r *http.Request) (string, error) { + auth := r.Header.Get("Authorization") + if auth == "" { + return "", errors.New("authorization header is required") + } + + return auth, nil +} + +func getQuery(r *http.Request) (string, error) { + token := r.URL.Query().Get("access_token") + if token == "" { + return "", errors.New("access_token query is required") + } + + token, err := url.QueryUnescape(token) + if err != nil { + return "", errors.New("access_token query is required") + } + + return token, nil +} + +func getCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie("hb.auth.token") + if err != nil { + return "", errors.New("access_token cookie is required") + } + + token, err := url.QueryUnescape(cookie.Value) + if err != nil { + return "", errors.New("access_token cookie is required") + } + + return token, nil +} + // mwAuthToken is a middleware that will check the database for a stateful token // and attach it's user to the request context, or return an appropriate error. // Authorization support is by token via Headers or Query Parameter @@ -75,26 +115,36 @@ func (a *app) mwRoles(rm RoleMode, required ...string) server.Middleware { // Example: // - header = "Bearer 1234567890" // - query = "?access_token=1234567890" +// - cookie = hb.auth.token = 1234567890 func (a *app) mwAuthToken(next server.Handler) server.Handler { return server.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - requestToken := r.Header.Get("Authorization") - if requestToken == "" { - // check for query param - requestToken = r.URL.Query().Get("access_token") - if requestToken == "" { - return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized) + keyFuncs := [...]KeyFunc{ + getBearer, + getCookie, + getQuery, + } + + var requestToken string + for _, keyFunc := range keyFuncs { + token, err := keyFunc(r) + if err == nil { + requestToken = token + break } } + if requestToken == "" { + return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized) + } + requestToken = strings.TrimPrefix(requestToken, "Bearer ") r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken)) usr, err := a.services.User.GetSelf(r.Context(), requestToken) - // Check the database for the token if err != nil { - return validate.NewRequestError(errors.New("Authorization header is required"), http.StatusUnauthorized) + return validate.NewRequestError(errors.New("valid authorization header is required"), http.StatusUnauthorized) } r = r.WithContext(services.SetUserCtx(r.Context(), &usr, requestToken)) diff --git a/frontend/components/App/Header.vue b/frontend/components/App/Header.vue index 1561c87..a142f8d 100644 --- a/frontend/components/App/Header.vue +++ b/frontend/components/App/Header.vue @@ -1,11 +1,9 @@