feat: add lookup package to extractor value from query,header,cookie

This commit is contained in:
thinkgo 2023-09-12 09:22:07 +08:00
parent 94fd9c314d
commit ba234f4bd0
5 changed files with 492 additions and 58 deletions

View file

@ -4,11 +4,10 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"net/url"
"strings"
"github.com/hay-kot/homebox/backend/internal/core/services" "github.com/hay-kot/homebox/backend/internal/core/services"
"github.com/hay-kot/homebox/backend/internal/sys/validate" "github.com/hay-kot/homebox/backend/internal/sys/validate"
"github.com/hay-kot/homebox/backend/pkgs/lookup"
"github.com/hay-kot/httpkit/errchain" "github.com/hay-kot/httpkit/errchain"
) )
@ -69,45 +68,6 @@ func (a *app) mwRoles(rm RoleMode, required ...string) errchain.Middleware {
} }
} }
type KeyFunc func(r *http.Request) (string, error)
func getBearer(r *http.Request) (string, error) {
auth := r.Header.Get("Authorization")
if auth == "" {
return "", errors.New("authorization header is required")
}
return auth, nil
}
func getQuery(r *http.Request) (string, error) {
token := r.URL.Query().Get("access_token")
if token == "" {
return "", errors.New("access_token query is required")
}
token, err := url.QueryUnescape(token)
if err != nil {
return "", errors.New("access_token query is required")
}
return token, nil
}
func getCookie(r *http.Request) (string, error) {
cookie, err := r.Cookie("hb.auth.token")
if err != nil {
return "", errors.New("access_token cookie is required")
}
token, err := url.QueryUnescape(cookie.Value)
if err != nil {
return "", errors.New("access_token cookie is required")
}
return token, nil
}
// mwAuthToken is a middleware that will check the database for a stateful token // mwAuthToken is a middleware that will check the database for a stateful token
// and attach it's user to the request context, or return an appropriate error. // and attach it's user to the request context, or return an appropriate error.
// Authorization support is by token via Headers or Query Parameter // Authorization support is by token via Headers or Query Parameter
@ -118,27 +78,16 @@ func getCookie(r *http.Request) (string, error) {
// - cookie = hb.auth.token = 1234567890 // - cookie = hb.auth.token = 1234567890
func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler { func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler {
return errchain.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return errchain.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
keyFuncs := [...]KeyFunc{ extractor := lookup.MultiExtractor{
getBearer, lookup.HeaderExtractor{Key: "Authorization", Prefix: "Bearer"},
getCookie, lookup.ArgumentExtractor("access_token"),
getQuery, lookup.CookieExtractor("hb.auth.token"),
} }
requestToken, err := extractor.ExtractValue(r)
var requestToken string if err != nil || requestToken == "" {
for _, keyFunc := range keyFuncs {
token, err := keyFunc(r)
if err == nil {
requestToken = token
break
}
}
if requestToken == "" {
return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized) return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized)
} }
requestToken = strings.TrimPrefix(requestToken, "Bearer ")
r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken)) r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken))
usr, err := a.services.User.GetSelf(r.Context(), requestToken) usr, err := a.services.User.GetSelf(r.Context(), requestToken)

View file

@ -0,0 +1,99 @@
package lookup
import (
"errors"
"net/http"
"net/url"
"strings"
)
// Extractor is an interface for extracting a value from an HTTP request.
// The ExtractValue method should return a value string or an error.
// If no value is present, you must return ErrMissingValue.
type Extractor interface {
ExtractValue(*http.Request) (string, error)
}
// MultiExtractor tries Extractors in order until one returns a value string or an error occurs
type MultiExtractor []Extractor
func (e MultiExtractor) ExtractValue(req *http.Request) (string, error) {
// loop over header names and return the first one that contains data
for _, extractor := range e {
if val, err := extractor.ExtractValue(req); val != "" {
return val, nil
} else if !errors.Is(err, ErrMissingValue) {
return "", err
}
}
return "", ErrMissingValue
}
// HeaderExtractor is an extractor for finding a value in a header.
// Looks at each specified header in order until there's a match
type HeaderExtractor struct {
// The key of the header
// Required
Key string
// Strips 'Bearer ' prefix from bearer value string.
// Possible value is "Bearer"
// Optional
Prefix string
}
func (e HeaderExtractor) ExtractValue(r *http.Request) (string, error) {
// loop over header names and return the first one that contains data
return stripHeadValuePrefixFromValueString(e.Prefix)(r.Header.Get(e.Key))
}
// ArgumentExtractor extracts a value from request arguments. This includes a POSTed form or
// GET URL arguments.
// This extractor calls `ParseMultipartForm` on the request
type ArgumentExtractor string
func (e ArgumentExtractor) ExtractValue(r *http.Request) (string, error) {
// Make sure form is parsed
_ = r.ParseMultipartForm(10e6)
tk := strings.TrimSpace(r.Form.Get(string(e)))
if tk != "" {
return tk, nil
}
return "", ErrMissingValue
}
// CookieExtractor extracts a value from cookie.
type CookieExtractor string
func (e CookieExtractor) ExtractValue(r *http.Request) (string, error) {
cookie, err := r.Cookie(string(e))
if err != nil {
return "", ErrMissingValue
}
val, _ := url.QueryUnescape(cookie.Value)
if val = strings.TrimSpace(val); val != "" {
return val, nil
}
return "", ErrMissingValue
}
// Strips like 'Bearer ' prefix from value string with header name
func stripHeadValuePrefixFromValueString(prefix string) func(string) (string, error) {
return func(tok string) (string, error) {
tok = strings.TrimSpace(tok)
if tok == "" {
return "", ErrMissingValue
}
l := len(prefix)
if l == 0 {
return tok, nil
}
// Should be a bearer value
if len(tok) > l && strings.EqualFold(tok[:l], prefix) {
if tok = strings.TrimSpace(tok[l+1:]); tok != "" {
return tok, nil
}
}
return "", ErrMissingValue
}
}

View file

@ -0,0 +1,132 @@
package lookup
import (
"fmt"
"net/http"
"net/url"
"testing"
)
func TestExtractor(t *testing.T) {
var extractorTestValue = "testTokenValue"
var tests = []struct {
name string
extractor Extractor
headers map[string]string
query url.Values
cookie map[string]string
wantValue string
wantErr error
}{
{
name: "header hit",
extractor: HeaderExtractor{"token", ""},
headers: map[string]string{"token": extractorTestValue},
query: nil,
cookie: nil,
wantValue: extractorTestValue,
wantErr: nil,
},
{
name: "header miss",
extractor: HeaderExtractor{"This-Header-Is-Not-Set", ""},
headers: map[string]string{"token": extractorTestValue},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "header filter",
extractor: HeaderExtractor{"Authorization", "Bearer"},
headers: map[string]string{"Authorization": "Bearer " + extractorTestValue},
query: nil,
cookie: nil,
wantValue: extractorTestValue,
wantErr: nil,
},
{
name: "header filter miss",
extractor: HeaderExtractor{"Authorization", "Bearer"},
headers: map[string]string{"Authorization": "Bearer "},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "argument hit",
extractor: ArgumentExtractor("token"),
headers: map[string]string{},
query: url.Values{"token": {extractorTestValue}},
cookie: nil,
wantValue: extractorTestValue,
wantErr: nil,
},
{
name: "argument miss",
extractor: ArgumentExtractor("token"),
headers: map[string]string{},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "cookie hit",
extractor: CookieExtractor("token"),
headers: map[string]string{},
query: nil,
cookie: map[string]string{"token": extractorTestValue},
wantValue: extractorTestValue,
wantErr: nil,
},
{
name: "cookie miss",
extractor: ArgumentExtractor("token"),
headers: map[string]string{},
query: nil,
cookie: map[string]string{},
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "cookie miss",
extractor: ArgumentExtractor("token"),
headers: map[string]string{},
query: nil,
cookie: map[string]string{"token": " "},
wantValue: "",
wantErr: ErrMissingValue,
},
}
// Bearer token request
for _, e := range tests {
// Make request from test struct
r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query)
// Test extractor
value, err := e.extractor.ExtractValue(r)
if value != e.wantValue {
t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value)
continue
}
if err != e.wantErr {
t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err)
continue
}
}
}
func makeTestRequest(method, path string, headers, cookie map[string]string, urlArgs url.Values) *http.Request {
r, _ := http.NewRequest(method, fmt.Sprintf("%v?%v", path, urlArgs.Encode()), nil)
for k, v := range headers {
r.Header.Set(k, v)
}
for k, v := range cookie {
r.AddCookie(&http.Cookie{Name: k, Value: v})
}
return r
}

View file

@ -0,0 +1,85 @@
package lookup
import (
"errors"
"net/http"
"strings"
)
// ErrMissingValue can be thrown by follow
// if value with a HTTP header, the value header needs to be set
// if value with URL Query, the query value variable is empty
// if value with a cookie, the value cookie is empty
var ErrMissingValue = errors.New("no value present in request")
// Lookup is a tool that looks up the value from http request, such as token
type Lookup struct {
extractors MultiExtractor
}
// NewLookup new a lookup.
// lookup is a string in the form of "<source>:<name>[:<prefix>]" that is used
// to extract value from the request.
// use like "header:<name>[:<prefix>],query:<name>,cookie:<name>,param:<name>"
// Optional, Default value "header:Authorization:Bearer" for json web token.
// Possible values:
// - "header:<name>:<prefix>", <prefix> is a special string in the header, Possible value is "Bearer"
// - "query:<name>"
// - "cookie:<name>"
func NewLookup(lookup string) *Lookup {
if lookup == "" {
lookup = "header:Authorization:Bearer"
}
methods := strings.Split(lookup, ",")
lookups := make(MultiExtractor, 0, len(methods))
for _, method := range methods {
parts := strings.Split(strings.TrimSpace(method), ":")
if !(len(parts) == 2 || len(parts) == 3) {
continue
}
switch parts[0] {
case "header":
prefix := ""
if len(parts) == 3 {
prefix = strings.TrimSpace(parts[2])
}
lookups = append(lookups, HeaderExtractor{strings.TrimSpace(parts[1]), prefix})
case "query":
lookups = append(lookups, ArgumentExtractor(parts[1]))
case "cookie":
lookups = append(lookups, CookieExtractor(parts[1]))
}
}
if len(lookups) == 0 {
lookups = append(lookups, HeaderExtractor{"Authorization", "Bearer"})
}
return &Lookup{lookups}
}
// ExtractValue extract value from http request.
func (sf *Lookup) ExtractValue(r *http.Request) (string, error) {
value, err := sf.extractors.ExtractValue(r)
if err != nil || value == "" {
return "", ErrMissingValue
}
return value, nil
}
// FromHeader get value from header
// key is a header key, like "Authorization"
// prefix is a string in the header, like "Bearer", if it is empty, only will return value.
func FromHeader(r *http.Request, key, prefix string) (string, error) {
return HeaderExtractor{key, prefix}.ExtractValue(r)
}
// FromQuery get value from query
// key is a query key
func FromQuery(r *http.Request, key string) (string, error) {
return ArgumentExtractor(key).ExtractValue(r)
}
// FromCookie get value from Cookie
// key is a cookie key
func FromCookie(r *http.Request, key string) (string, error) {
return CookieExtractor(key).ExtractValue(r)
}

View file

@ -0,0 +1,169 @@
package lookup
import (
"net/url"
"testing"
"github.com/stretchr/testify/require"
)
func TestLookup(t *testing.T) {
var extractorTestTokenValue = "testTokenValue"
var tests = []struct {
name string
lookup string
headers map[string]string
query url.Values
cookie map[string]string
wantValue string
wantErr error
}{
{
name: "invalid lookup but use default",
lookup: "xx",
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "ignore invalid lookup",
lookup: "header:Authorization:Bearer,xxxx",
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "header default hit",
lookup: "",
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "header hit",
lookup: "header:token",
headers: map[string]string{"token": extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "header miss",
lookup: "header:This-Header-Is-Not-Set",
headers: map[string]string{"token": extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "header filter",
lookup: "header:Authorization:Bearer",
headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue},
query: nil,
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "header filter miss",
lookup: "header:Authorization:Bearer",
headers: map[string]string{"Authorization": "Bearer "},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "argument hit",
lookup: "query:token",
headers: map[string]string{},
query: url.Values{"token": {extractorTestTokenValue}},
cookie: nil,
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "argument miss",
lookup: "query:token",
headers: map[string]string{},
query: nil,
cookie: nil,
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "cookie hit",
lookup: "cookie:token",
headers: map[string]string{},
query: nil,
cookie: map[string]string{"token": extractorTestTokenValue},
wantValue: extractorTestTokenValue,
wantErr: nil,
},
{
name: "cookie miss",
lookup: "cookie:token",
headers: map[string]string{},
query: nil,
cookie: map[string]string{},
wantValue: "",
wantErr: ErrMissingValue,
},
{
name: "cookie miss",
lookup: "cookie:token",
headers: map[string]string{},
query: nil,
cookie: map[string]string{"token": " "},
wantValue: "",
wantErr: ErrMissingValue,
},
}
// Bearer token request
for _, e := range tests {
// Make request from test struct
r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query)
// Test extractor
value, err := NewLookup(e.lookup).ExtractValue(r)
if value != e.wantValue {
t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value)
continue
}
if err != e.wantErr {
t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err)
continue
}
}
}
func TestFrom(t *testing.T) {
t.Run("from header", func(t *testing.T) {
r := makeTestRequest("GET", "/", map[string]string{"token": "foo"}, nil, nil)
tk, err := FromHeader(r, "token", "")
require.NoError(t, err)
require.Equal(t, "foo", tk)
})
t.Run("from query", func(t *testing.T) {
r := makeTestRequest("GET", "/", nil, nil, url.Values{"token": {"foo"}})
tk, err := FromQuery(r, "token")
require.NoError(t, err)
require.Equal(t, "foo", tk)
})
t.Run("from query", func(t *testing.T) {
r := makeTestRequest("GET", "/", nil, map[string]string{"token": "foo"}, nil)
tk, err := FromCookie(r, "token")
require.NoError(t, err)
require.Equal(t, "foo", tk)
})
}