diff --git a/backend/app/api/middleware.go b/backend/app/api/middleware.go index fb6b9cf..eac0de0 100644 --- a/backend/app/api/middleware.go +++ b/backend/app/api/middleware.go @@ -4,11 +4,10 @@ import ( "context" "errors" "net/http" - "net/url" - "strings" "github.com/hay-kot/homebox/backend/internal/core/services" "github.com/hay-kot/homebox/backend/internal/sys/validate" + "github.com/hay-kot/homebox/backend/pkgs/lookup" "github.com/hay-kot/httpkit/errchain" ) @@ -69,45 +68,6 @@ func (a *app) mwRoles(rm RoleMode, required ...string) errchain.Middleware { } } -type KeyFunc func(r *http.Request) (string, error) - -func getBearer(r *http.Request) (string, error) { - auth := r.Header.Get("Authorization") - if auth == "" { - return "", errors.New("authorization header is required") - } - - return auth, nil -} - -func getQuery(r *http.Request) (string, error) { - token := r.URL.Query().Get("access_token") - if token == "" { - return "", errors.New("access_token query is required") - } - - token, err := url.QueryUnescape(token) - if err != nil { - return "", errors.New("access_token query is required") - } - - return token, nil -} - -func getCookie(r *http.Request) (string, error) { - cookie, err := r.Cookie("hb.auth.token") - if err != nil { - return "", errors.New("access_token cookie is required") - } - - token, err := url.QueryUnescape(cookie.Value) - if err != nil { - return "", errors.New("access_token cookie is required") - } - - return token, nil -} - // mwAuthToken is a middleware that will check the database for a stateful token // and attach it's user to the request context, or return an appropriate error. // Authorization support is by token via Headers or Query Parameter @@ -118,27 +78,16 @@ func getCookie(r *http.Request) (string, error) { // - cookie = hb.auth.token = 1234567890 func (a *app) mwAuthToken(next errchain.Handler) errchain.Handler { return errchain.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - keyFuncs := [...]KeyFunc{ - getBearer, - getCookie, - getQuery, + extractor := lookup.MultiExtractor{ + lookup.HeaderExtractor{Key: "Authorization", Prefix: "Bearer"}, + lookup.ArgumentExtractor("access_token"), + lookup.CookieExtractor("hb.auth.token"), } - - var requestToken string - for _, keyFunc := range keyFuncs { - token, err := keyFunc(r) - if err == nil { - requestToken = token - break - } - } - - if requestToken == "" { + requestToken, err := extractor.ExtractValue(r) + if err != nil || requestToken == "" { return validate.NewRequestError(errors.New("Authorization header or query is required"), http.StatusUnauthorized) } - requestToken = strings.TrimPrefix(requestToken, "Bearer ") - r = r.WithContext(context.WithValue(r.Context(), hashedToken, requestToken)) usr, err := a.services.User.GetSelf(r.Context(), requestToken) diff --git a/backend/pkgs/lookup/extractor.go b/backend/pkgs/lookup/extractor.go new file mode 100644 index 0000000..5b4544f --- /dev/null +++ b/backend/pkgs/lookup/extractor.go @@ -0,0 +1,99 @@ +package lookup + +import ( + "errors" + "net/http" + "net/url" + "strings" +) + +// Extractor is an interface for extracting a value from an HTTP request. +// The ExtractValue method should return a value string or an error. +// If no value is present, you must return ErrMissingValue. +type Extractor interface { + ExtractValue(*http.Request) (string, error) +} + +// MultiExtractor tries Extractors in order until one returns a value string or an error occurs +type MultiExtractor []Extractor + +func (e MultiExtractor) ExtractValue(req *http.Request) (string, error) { + // loop over header names and return the first one that contains data + for _, extractor := range e { + if val, err := extractor.ExtractValue(req); val != "" { + return val, nil + } else if !errors.Is(err, ErrMissingValue) { + return "", err + } + } + return "", ErrMissingValue +} + +// HeaderExtractor is an extractor for finding a value in a header. +// Looks at each specified header in order until there's a match +type HeaderExtractor struct { + // The key of the header + // Required + Key string + // Strips 'Bearer ' prefix from bearer value string. + // Possible value is "Bearer" + // Optional + Prefix string +} + +func (e HeaderExtractor) ExtractValue(r *http.Request) (string, error) { + // loop over header names and return the first one that contains data + return stripHeadValuePrefixFromValueString(e.Prefix)(r.Header.Get(e.Key)) +} + +// ArgumentExtractor extracts a value from request arguments. This includes a POSTed form or +// GET URL arguments. +// This extractor calls `ParseMultipartForm` on the request +type ArgumentExtractor string + +func (e ArgumentExtractor) ExtractValue(r *http.Request) (string, error) { + // Make sure form is parsed + _ = r.ParseMultipartForm(10e6) + + tk := strings.TrimSpace(r.Form.Get(string(e))) + if tk != "" { + return tk, nil + } + return "", ErrMissingValue +} + +// CookieExtractor extracts a value from cookie. +type CookieExtractor string + +func (e CookieExtractor) ExtractValue(r *http.Request) (string, error) { + cookie, err := r.Cookie(string(e)) + if err != nil { + return "", ErrMissingValue + } + val, _ := url.QueryUnescape(cookie.Value) + if val = strings.TrimSpace(val); val != "" { + return val, nil + } + return "", ErrMissingValue +} + +// Strips like 'Bearer ' prefix from value string with header name +func stripHeadValuePrefixFromValueString(prefix string) func(string) (string, error) { + return func(tok string) (string, error) { + tok = strings.TrimSpace(tok) + if tok == "" { + return "", ErrMissingValue + } + l := len(prefix) + if l == 0 { + return tok, nil + } + // Should be a bearer value + if len(tok) > l && strings.EqualFold(tok[:l], prefix) { + if tok = strings.TrimSpace(tok[l+1:]); tok != "" { + return tok, nil + } + } + return "", ErrMissingValue + } +} diff --git a/backend/pkgs/lookup/extractor_test.go b/backend/pkgs/lookup/extractor_test.go new file mode 100644 index 0000000..9d1861f --- /dev/null +++ b/backend/pkgs/lookup/extractor_test.go @@ -0,0 +1,132 @@ +package lookup + +import ( + "fmt" + "net/http" + "net/url" + "testing" +) + +func TestExtractor(t *testing.T) { + var extractorTestValue = "testTokenValue" + + var tests = []struct { + name string + extractor Extractor + headers map[string]string + query url.Values + cookie map[string]string + wantValue string + wantErr error + }{ + { + name: "header hit", + extractor: HeaderExtractor{"token", ""}, + headers: map[string]string{"token": extractorTestValue}, + query: nil, + cookie: nil, + wantValue: extractorTestValue, + wantErr: nil, + }, + { + name: "header miss", + extractor: HeaderExtractor{"This-Header-Is-Not-Set", ""}, + headers: map[string]string{"token": extractorTestValue}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + + { + name: "header filter", + extractor: HeaderExtractor{"Authorization", "Bearer"}, + headers: map[string]string{"Authorization": "Bearer " + extractorTestValue}, + query: nil, + cookie: nil, + wantValue: extractorTestValue, + wantErr: nil, + }, + { + name: "header filter miss", + extractor: HeaderExtractor{"Authorization", "Bearer"}, + headers: map[string]string{"Authorization": "Bearer "}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "argument hit", + extractor: ArgumentExtractor("token"), + headers: map[string]string{}, + query: url.Values{"token": {extractorTestValue}}, + cookie: nil, + wantValue: extractorTestValue, + wantErr: nil, + }, + { + name: "argument miss", + extractor: ArgumentExtractor("token"), + headers: map[string]string{}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "cookie hit", + extractor: CookieExtractor("token"), + headers: map[string]string{}, + query: nil, + cookie: map[string]string{"token": extractorTestValue}, + wantValue: extractorTestValue, + wantErr: nil, + }, + { + name: "cookie miss", + extractor: ArgumentExtractor("token"), + headers: map[string]string{}, + query: nil, + cookie: map[string]string{}, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "cookie miss", + extractor: ArgumentExtractor("token"), + headers: map[string]string{}, + query: nil, + cookie: map[string]string{"token": " "}, + wantValue: "", + wantErr: ErrMissingValue, + }, + } + // Bearer token request + for _, e := range tests { + // Make request from test struct + r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query) + + // Test extractor + value, err := e.extractor.ExtractValue(r) + if value != e.wantValue { + t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value) + continue + } + if err != e.wantErr { + t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err) + continue + } + } +} + +func makeTestRequest(method, path string, headers, cookie map[string]string, urlArgs url.Values) *http.Request { + r, _ := http.NewRequest(method, fmt.Sprintf("%v?%v", path, urlArgs.Encode()), nil) + for k, v := range headers { + r.Header.Set(k, v) + } + for k, v := range cookie { + r.AddCookie(&http.Cookie{Name: k, Value: v}) + } + return r +} diff --git a/backend/pkgs/lookup/lookup.go b/backend/pkgs/lookup/lookup.go new file mode 100644 index 0000000..d44c862 --- /dev/null +++ b/backend/pkgs/lookup/lookup.go @@ -0,0 +1,85 @@ +package lookup + +import ( + "errors" + "net/http" + "strings" +) + +// ErrMissingValue can be thrown by follow +// if value with a HTTP header, the value header needs to be set +// if value with URL Query, the query value variable is empty +// if value with a cookie, the value cookie is empty +var ErrMissingValue = errors.New("no value present in request") + +// Lookup is a tool that looks up the value from http request, such as token +type Lookup struct { + extractors MultiExtractor +} + +// NewLookup new a lookup. +// lookup is a string in the form of ":[:]" that is used +// to extract value from the request. +// use like "header:[:],query:,cookie:,param:" +// Optional, Default value "header:Authorization:Bearer" for json web token. +// Possible values: +// - "header::", is a special string in the header, Possible value is "Bearer" +// - "query:" +// - "cookie:" +func NewLookup(lookup string) *Lookup { + if lookup == "" { + lookup = "header:Authorization:Bearer" + } + methods := strings.Split(lookup, ",") + lookups := make(MultiExtractor, 0, len(methods)) + for _, method := range methods { + parts := strings.Split(strings.TrimSpace(method), ":") + if !(len(parts) == 2 || len(parts) == 3) { + continue + } + switch parts[0] { + case "header": + prefix := "" + if len(parts) == 3 { + prefix = strings.TrimSpace(parts[2]) + } + lookups = append(lookups, HeaderExtractor{strings.TrimSpace(parts[1]), prefix}) + case "query": + lookups = append(lookups, ArgumentExtractor(parts[1])) + case "cookie": + lookups = append(lookups, CookieExtractor(parts[1])) + } + } + if len(lookups) == 0 { + lookups = append(lookups, HeaderExtractor{"Authorization", "Bearer"}) + } + return &Lookup{lookups} +} + +// ExtractValue extract value from http request. +func (sf *Lookup) ExtractValue(r *http.Request) (string, error) { + value, err := sf.extractors.ExtractValue(r) + if err != nil || value == "" { + return "", ErrMissingValue + } + return value, nil +} + +// FromHeader get value from header +// key is a header key, like "Authorization" +// prefix is a string in the header, like "Bearer", if it is empty, only will return value. +func FromHeader(r *http.Request, key, prefix string) (string, error) { + return HeaderExtractor{key, prefix}.ExtractValue(r) +} + +// FromQuery get value from query +// key is a query key +func FromQuery(r *http.Request, key string) (string, error) { + return ArgumentExtractor(key).ExtractValue(r) +} + +// FromCookie get value from Cookie +// key is a cookie key +func FromCookie(r *http.Request, key string) (string, error) { + return CookieExtractor(key).ExtractValue(r) +} diff --git a/backend/pkgs/lookup/lookup_test.go b/backend/pkgs/lookup/lookup_test.go new file mode 100644 index 0000000..880e2c5 --- /dev/null +++ b/backend/pkgs/lookup/lookup_test.go @@ -0,0 +1,169 @@ +package lookup + +import ( + "net/url" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestLookup(t *testing.T) { + var extractorTestTokenValue = "testTokenValue" + + var tests = []struct { + name string + lookup string + headers map[string]string + query url.Values + cookie map[string]string + wantValue string + wantErr error + }{ + { + name: "invalid lookup but use default", + lookup: "xx", + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "ignore invalid lookup", + lookup: "header:Authorization:Bearer,xxxx", + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "header default hit", + lookup: "", + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "header hit", + lookup: "header:token", + headers: map[string]string{"token": extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "header miss", + lookup: "header:This-Header-Is-Not-Set", + headers: map[string]string{"token": extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + + { + name: "header filter", + lookup: "header:Authorization:Bearer", + headers: map[string]string{"Authorization": "Bearer " + extractorTestTokenValue}, + query: nil, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "header filter miss", + lookup: "header:Authorization:Bearer", + headers: map[string]string{"Authorization": "Bearer "}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "argument hit", + lookup: "query:token", + headers: map[string]string{}, + query: url.Values{"token": {extractorTestTokenValue}}, + cookie: nil, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "argument miss", + lookup: "query:token", + headers: map[string]string{}, + query: nil, + cookie: nil, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "cookie hit", + lookup: "cookie:token", + headers: map[string]string{}, + query: nil, + cookie: map[string]string{"token": extractorTestTokenValue}, + wantValue: extractorTestTokenValue, + wantErr: nil, + }, + { + name: "cookie miss", + lookup: "cookie:token", + headers: map[string]string{}, + query: nil, + cookie: map[string]string{}, + wantValue: "", + wantErr: ErrMissingValue, + }, + { + name: "cookie miss", + lookup: "cookie:token", + headers: map[string]string{}, + query: nil, + cookie: map[string]string{"token": " "}, + wantValue: "", + wantErr: ErrMissingValue, + }, + } + // Bearer token request + for _, e := range tests { + // Make request from test struct + r := makeTestRequest("GET", "/", e.headers, e.cookie, e.query) + + // Test extractor + value, err := NewLookup(e.lookup).ExtractValue(r) + if value != e.wantValue { + t.Errorf("[%v] Expected value '%v'. Got '%v'", e.name, e.wantValue, value) + continue + } + if err != e.wantErr { + t.Errorf("[%v] Expected error '%v'. Got '%v'", e.name, e.wantErr, err) + continue + } + } +} + +func TestFrom(t *testing.T) { + t.Run("from header", func(t *testing.T) { + r := makeTestRequest("GET", "/", map[string]string{"token": "foo"}, nil, nil) + tk, err := FromHeader(r, "token", "") + require.NoError(t, err) + require.Equal(t, "foo", tk) + }) + t.Run("from query", func(t *testing.T) { + r := makeTestRequest("GET", "/", nil, nil, url.Values{"token": {"foo"}}) + tk, err := FromQuery(r, "token") + require.NoError(t, err) + require.Equal(t, "foo", tk) + }) + t.Run("from query", func(t *testing.T) { + r := makeTestRequest("GET", "/", nil, map[string]string{"token": "foo"}, nil) + tk, err := FromCookie(r, "token") + require.NoError(t, err) + require.Equal(t, "foo", tk) + }) +}