mirror of
https://github.com/hay-kot/homebox.git
synced 2025-08-02 07:40:28 +00:00
feat: add lookup package to extractor value from query,header,cookie
This commit is contained in:
parent
94fd9c314d
commit
ba234f4bd0
5 changed files with 492 additions and 58 deletions
|
@ -4,11 +4,10 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/hay-kot/homebox/backend/internal/core/services"
|
||||
"github.com/hay-kot/homebox/backend/internal/sys/validate"
|
||||
"github.com/hay-kot/homebox/backend/pkgs/lookup"
|
||||
"github.com/hay-kot/httpkit/errchain"
|
||||
)
|
||||
|
||||
|
@ -69,45 +68,6 @@ func (a *app) mwRoles(rm RoleMode, required ...string) errchain.Middleware {
|
|||
}
|
||||
}
|
||||
|
||||
type KeyFunc func(r *http.Request) (string, error)
|
||||
|
||||
func getBearer(r *http.Request) (string, error) {
|
||||
auth := r.Header.Get("Authorization")
|
||||
if auth == "" {
|
||||
return "", errors.New("authorization header is required")
|
||||
}
|
||||
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func getQuery(r *http.Request) (string, error) {
|
||||
token := r.URL.Query().Get("access_token")
|
||||
if token == "" {
|
||||
return "", errors.New("access_token query is required")
|
||||
}
|
||||
|
||||
token, err := url.QueryUnescape(token)
|
||||
if err != nil {
|
||||
return "", errors.New("access_token query is required")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getCookie(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie("hb.auth.token")
|
||||
if err != nil {
|
||||
return "", errors.New("access_token cookie is required")
|
||||
}
|
||||
|
||||
token, err := url.QueryUnescape(cookie.Value)
|
||||
if err != nil {
|
||||
return "", errors.New("access_token cookie is required")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// mwAuthToken is a middleware that will check the database for a stateful token
|
||||
// and attach it's user to the request context, or return an appropriate error.
|
||||
// Authorization support is by token via Headers or Query Parameter
|
||||
|
@ -118,27 +78,16 @@ func getCookie(r *http.Request) (string, error) {
|
|||
// - cookie = hb.auth.token = 1234567890
|
||||
func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler {
|
||||
return errchain.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
keyFuncs := [...]KeyFunc{
|
||||
getBearer,
|
||||
getCookie,
|
||||
getQuery,
|
||||
extractor := lookup.MultiExtractor{
|
||||
lookup.HeaderExtractor{Key: "Authorization", Prefix: "Bearer"},
|
||||
lookup.ArgumentExtractor("access_token"),
|
||||
lookup.CookieExtractor("hb.auth.token"),
|
||||
}
|
||||
|
||||
var requestToken string
|
||||
for _, keyFunc := range keyFuncs {
|
||||
token, err := keyFunc(r)
|
||||
if err == nil {
|
||||
requestToken = token
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if requestToken == "" {
|
||||
requestToken, err := extractor.ExtractValue(r)
|
||||
if err != nil || requestToken == "" {
|
||||
return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized)
|
||||
}
|
||||
|
||||
requestToken = strings.TrimPrefix(requestToken, "Bearer ")
|
||||
|
||||
r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken))
|
||||
|
||||
usr, err := a.services.User.GetSelf(r.Context(), requestToken)
|
||||
|
|
99
backend/pkgs/lookup/extractor.go
Normal file
99
backend/pkgs/lookup/extractor.go
Normal file
|
@ -0,0 +1,99 @@
|
|||
package lookup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Extractor is an interface for extracting a value from an HTTP request.
|
||||
// The ExtractValue method should return a value string or an error.
|
||||
// If no value is present, you must return ErrMissingValue.
|
||||
type Extractor interface {
|
||||
ExtractValue(*http.Request) (string, error)
|
||||
}
|
||||
|
||||
// MultiExtractor tries Extractors in order until one returns a value string or an error occurs
|
||||
type MultiExtractor []Extractor
|
||||
|
||||
func (e MultiExtractor) ExtractValue(req *http.Request) (string, error) {
|
||||
// loop over header names and return the first one that contains data
|
||||
for _, extractor := range e {
|
||||
if val, err := extractor.ExtractValue(req); val != "" {
|
||||
return val, nil
|
||||
} else if !errors.Is(err, ErrMissingValue) {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
|
||||
// HeaderExtractor is an extractor for finding a value in a header.
|
||||
// Looks at each specified header in order until there's a match
|
||||
type HeaderExtractor struct {
|
||||
// The key of the header
|
||||
// Required
|
||||
Key string
|
||||
// Strips 'Bearer ' prefix from bearer value string.
|
||||
// Possible value is "Bearer"
|
||||
// Optional
|
||||
Prefix string
|
||||
}
|
||||
|
||||
func (e HeaderExtractor) ExtractValue(r *http.Request) (string, error) {
|
||||
// loop over header names and return the first one that contains data
|
||||
return stripHeadValuePrefixFromValueString(e.Prefix)(r.Header.Get(e.Key))
|
||||
}
|
||||
|
||||
// ArgumentExtractor extracts a value from request arguments. This includes a POSTed form or
|
||||
// GET URL arguments.
|
||||
// This extractor calls `ParseMultipartForm` on the request
|
||||
type ArgumentExtractor string
|
||||
|
||||
func (e ArgumentExtractor) ExtractValue(r *http.Request) (string, error) {
|
||||
// Make sure form is parsed
|
||||
_ = r.ParseMultipartForm(10e6)
|
||||
|
||||
tk := strings.TrimSpace(r.Form.Get(string(e)))
|
||||
if tk != "" {
|
||||
return tk, nil
|
||||
}
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
|
||||
// CookieExtractor extracts a value from cookie.
|
||||
type CookieExtractor string
|
||||
|
||||
func (e CookieExtractor) ExtractValue(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie(string(e))
|
||||
if err != nil {
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
val, _ := url.QueryUnescape(cookie.Value)
|
||||
if val = strings.TrimSpace(val); val != "" {
|
||||
return val, nil
|
||||
}
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
|
||||
// Strips like 'Bearer ' prefix from value string with header name
|
||||
func stripHeadValuePrefixFromValueString(prefix string) func(string) (string, error) {
|
||||
return func(tok string) (string, error) {
|
||||
tok = strings.TrimSpace(tok)
|
||||
if tok == "" {
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
l := len(prefix)
|
||||
if l == 0 {
|
||||
return tok, nil
|
||||
}
|
||||
// Should be a bearer value
|
||||
if len(tok) > l && strings.EqualFold(tok[:l], prefix) {
|
||||
if tok = strings.TrimSpace(tok[l+1:]); tok != "" {
|
||||
return tok, nil
|
||||
}
|
||||
}
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
}
|
132
backend/pkgs/lookup/extractor_test.go
Normal file
132
backend/pkgs/lookup/extractor_test.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
package lookup
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractor(t *testing.T) {
|
||||
var extractorTestValue = "testTokenValue"
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
extractor Extractor
|
||||
headers map[string]string
|
||||
query url.Values
|
||||
cookie map[string]string
|
||||
wantValue string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "header hit",
|
||||
extractor: HeaderExtractor{"token", ""},
|
||||
headers: map[string]string{"token": extractorTestValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header miss",
|
||||
extractor: HeaderExtractor{"This-Header-Is-Not-Set", ""},
|
||||
headers: map[string]string{"token": extractorTestValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
|
||||
{
|
||||
name: "header filter",
|
||||
extractor: HeaderExtractor{"Authorization", "Bearer"},
|
||||
headers: map[string]string{"Authorization": "Bearer " + extractorTestValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header filter miss",
|
||||
extractor: HeaderExtractor{"Authorization", "Bearer"},
|
||||
headers: map[string]string{"Authorization": "Bearer "},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "argument hit",
|
||||
extractor: ArgumentExtractor("token"),
|
||||
headers: map[string]string{},
|
||||
query: url.Values{"token": {extractorTestValue}},
|
||||
cookie: nil,
|
||||
wantValue: extractorTestValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "argument miss",
|
||||
extractor: ArgumentExtractor("token"),
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "cookie hit",
|
||||
extractor: CookieExtractor("token"),
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{"token": extractorTestValue},
|
||||
wantValue: extractorTestValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie miss",
|
||||
extractor: ArgumentExtractor("token"),
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{},
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "cookie miss",
|
||||
extractor: ArgumentExtractor("token"),
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{"token": " "},
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
}
|
||||
// Bearer token request
|
||||
for _, e := range tests {
|
||||
// Make request from test struct
|
||||
r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query)
|
||||
|
||||
// Test extractor
|
||||
value, err := e.extractor.ExtractValue(r)
|
||||
if value != e.wantValue {
|
||||
t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value)
|
||||
continue
|
||||
}
|
||||
if err != e.wantErr {
|
||||
t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func makeTestRequest(method, path string, headers, cookie map[string]string, urlArgs url.Values) *http.Request {
|
||||
r, _ := http.NewRequest(method, fmt.Sprintf("%v?%v", path, urlArgs.Encode()), nil)
|
||||
for k, v := range headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
for k, v := range cookie {
|
||||
r.AddCookie(&http.Cookie{Name: k, Value: v})
|
||||
}
|
||||
return r
|
||||
}
|
85
backend/pkgs/lookup/lookup.go
Normal file
85
backend/pkgs/lookup/lookup.go
Normal file
|
@ -0,0 +1,85 @@
|
|||
package lookup
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrMissingValue can be thrown by follow
|
||||
// if value with a HTTP header, the value header needs to be set
|
||||
// if value with URL Query, the query value variable is empty
|
||||
// if value with a cookie, the value cookie is empty
|
||||
var ErrMissingValue = errors.New("no value present in request")
|
||||
|
||||
// Lookup is a tool that looks up the value from http request, such as token
|
||||
type Lookup struct {
|
||||
extractors MultiExtractor
|
||||
}
|
||||
|
||||
// NewLookup new a lookup.
|
||||
// lookup is a string in the form of "<source>:<name>[:<prefix>]" that is used
|
||||
// to extract value from the request.
|
||||
// use like "header:<name>[:<prefix>],query:<name>,cookie:<name>,param:<name>"
|
||||
// Optional, Default value "header:Authorization:Bearer" for json web token.
|
||||
// Possible values:
|
||||
// - "header:<name>:<prefix>", <prefix> is a special string in the header, Possible value is "Bearer"
|
||||
// - "query:<name>"
|
||||
// - "cookie:<name>"
|
||||
func NewLookup(lookup string) *Lookup {
|
||||
if lookup == "" {
|
||||
lookup = "header:Authorization:Bearer"
|
||||
}
|
||||
methods := strings.Split(lookup, ",")
|
||||
lookups := make(MultiExtractor, 0, len(methods))
|
||||
for _, method := range methods {
|
||||
parts := strings.Split(strings.TrimSpace(method), ":")
|
||||
if !(len(parts) == 2 || len(parts) == 3) {
|
||||
continue
|
||||
}
|
||||
switch parts[0] {
|
||||
case "header":
|
||||
prefix := ""
|
||||
if len(parts) == 3 {
|
||||
prefix = strings.TrimSpace(parts[2])
|
||||
}
|
||||
lookups = append(lookups, HeaderExtractor{strings.TrimSpace(parts[1]), prefix})
|
||||
case "query":
|
||||
lookups = append(lookups, ArgumentExtractor(parts[1]))
|
||||
case "cookie":
|
||||
lookups = append(lookups, CookieExtractor(parts[1]))
|
||||
}
|
||||
}
|
||||
if len(lookups) == 0 {
|
||||
lookups = append(lookups, HeaderExtractor{"Authorization", "Bearer"})
|
||||
}
|
||||
return &Lookup{lookups}
|
||||
}
|
||||
|
||||
// ExtractValue extract value from http request.
|
||||
func (sf *Lookup) ExtractValue(r *http.Request) (string, error) {
|
||||
value, err := sf.extractors.ExtractValue(r)
|
||||
if err != nil || value == "" {
|
||||
return "", ErrMissingValue
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// FromHeader get value from header
|
||||
// key is a header key, like "Authorization"
|
||||
// prefix is a string in the header, like "Bearer", if it is empty, only will return value.
|
||||
func FromHeader(r *http.Request, key, prefix string) (string, error) {
|
||||
return HeaderExtractor{key, prefix}.ExtractValue(r)
|
||||
}
|
||||
|
||||
// FromQuery get value from query
|
||||
// key is a query key
|
||||
func FromQuery(r *http.Request, key string) (string, error) {
|
||||
return ArgumentExtractor(key).ExtractValue(r)
|
||||
}
|
||||
|
||||
// FromCookie get value from Cookie
|
||||
// key is a cookie key
|
||||
func FromCookie(r *http.Request, key string) (string, error) {
|
||||
return CookieExtractor(key).ExtractValue(r)
|
||||
}
|
169
backend/pkgs/lookup/lookup_test.go
Normal file
169
backend/pkgs/lookup/lookup_test.go
Normal file
|
@ -0,0 +1,169 @@
|
|||
package lookup
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLookup(t *testing.T) {
|
||||
var extractorTestTokenValue = "testTokenValue"
|
||||
|
||||
var tests = []struct {
|
||||
name string
|
||||
lookup string
|
||||
headers map[string]string
|
||||
query url.Values
|
||||
cookie map[string]string
|
||||
wantValue string
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "invalid lookup but use default",
|
||||
lookup: "xx",
|
||||
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "ignore invalid lookup",
|
||||
lookup: "header:Authorization:Bearer,xxxx",
|
||||
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header default hit",
|
||||
lookup: "",
|
||||
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header hit",
|
||||
lookup: "header:token",
|
||||
headers: map[string]string{"token": extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header miss",
|
||||
lookup: "header:This-Header-Is-Not-Set",
|
||||
headers: map[string]string{"token": extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
|
||||
{
|
||||
name: "header filter",
|
||||
lookup: "header:Authorization:Bearer",
|
||||
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "header filter miss",
|
||||
lookup: "header:Authorization:Bearer",
|
||||
headers: map[string]string{"Authorization": "Bearer "},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "argument hit",
|
||||
lookup: "query:token",
|
||||
headers: map[string]string{},
|
||||
query: url.Values{"token": {extractorTestTokenValue}},
|
||||
cookie: nil,
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "argument miss",
|
||||
lookup: "query:token",
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: nil,
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "cookie hit",
|
||||
lookup: "cookie:token",
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{"token": extractorTestTokenValue},
|
||||
wantValue: extractorTestTokenValue,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "cookie miss",
|
||||
lookup: "cookie:token",
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{},
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
{
|
||||
name: "cookie miss",
|
||||
lookup: "cookie:token",
|
||||
headers: map[string]string{},
|
||||
query: nil,
|
||||
cookie: map[string]string{"token": " "},
|
||||
wantValue: "",
|
||||
wantErr: ErrMissingValue,
|
||||
},
|
||||
}
|
||||
// Bearer token request
|
||||
for _, e := range tests {
|
||||
// Make request from test struct
|
||||
r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query)
|
||||
|
||||
// Test extractor
|
||||
value, err := NewLookup(e.lookup).ExtractValue(r)
|
||||
if value != e.wantValue {
|
||||
t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value)
|
||||
continue
|
||||
}
|
||||
if err != e.wantErr {
|
||||
t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFrom(t *testing.T) {
|
||||
t.Run("from header", func(t *testing.T) {
|
||||
r := makeTestRequest("GET", "/", map[string]string{"token": "foo"}, nil, nil)
|
||||
tk, err := FromHeader(r, "token", "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo", tk)
|
||||
})
|
||||
t.Run("from query", func(t *testing.T) {
|
||||
r := makeTestRequest("GET", "/", nil, nil, url.Values{"token": {"foo"}})
|
||||
tk, err := FromQuery(r, "token")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo", tk)
|
||||
})
|
||||
t.Run("from query", func(t *testing.T) {
|
||||
r := makeTestRequest("GET", "/", nil, map[string]string{"token": "foo"}, nil)
|
||||
tk, err := FromCookie(r, "token")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "foo", tk)
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue