From d97d94537a8caf9cfbb01ec5b659704267650c87 Mon Sep 17 00:00:00 2001 From: Adnan Hajdarevic Date: Tue, 26 Jan 2021 22:19:46 +0100 Subject: [PATCH] Add IsNull and Exists types to the Match rule --- internal/hook/hook.go | 57 +++++++++++++++++++++----------------- internal/hook/hook_test.go | 4 +-- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/internal/hook/hook.go b/internal/hook/hook.go index 0510095..b070ad3 100644 --- a/internal/hook/hook.go +++ b/internal/hook/hook.go @@ -394,7 +394,7 @@ func GetParameter(s string, params interface{}) (interface{}, error) { return v, nil } - // Checked for dotted references + // Check for dotted references p := strings.SplitN(s, ".", 2) if pValue, ok := params.(map[string]interface{})[p[0]]; ok { if len(p) > 1 { @@ -411,23 +411,23 @@ func GetParameter(s string, params interface{}) (interface{}, error) { // ExtractParameterAsString extracts value from interface{} as string based on // the passed string. Complex data types are rendered as JSON instead of the Go // Stringer format. -func ExtractParameterAsString(s string, params interface{}) (string, error) { +func ExtractParameterAsString(s string, params interface{}) (string, interface{}, error) { pValue, err := GetParameter(s, params) if err != nil { - return "", err + return "", nil, err } switch v := reflect.ValueOf(pValue); v.Kind() { case reflect.Array, reflect.Map, reflect.Slice: r, err := json.Marshal(pValue) if err != nil { - return "", err + return "", pValue, err } - return string(r), nil + return string(r), r, nil default: - return fmt.Sprintf("%v", pValue), nil + return fmt.Sprintf("%v", pValue), pValue, nil } } @@ -442,7 +442,7 @@ type Argument struct { // Get Argument method returns the value for the Argument's key name // based on the Argument's source -func (ha *Argument) Get(r *Request) (string, error) { +func (ha *Argument) Get(r *Request) (string, interface{}, error) { var source *map[string]interface{} key := ha.Name @@ -458,55 +458,55 @@ func (ha *Argument) Get(r *Request) (string, error) { source = &r.Payload case SourceString: - return ha.Name, nil + return ha.Name, ha.Name, nil case SourceRawRequestBody: - return string(r.Body), nil + return string(r.Body), r.Body, nil case SourceRequest: if r == nil || r.RawRequest == nil { - return "", errors.New("request is nil") + return "", nil, errors.New("request is nil") } switch strings.ToLower(ha.Name) { case "remote-addr": - return r.RawRequest.RemoteAddr, nil + return r.RawRequest.RemoteAddr, r.RawRequest.RemoteAddr, nil case "method": - return r.RawRequest.Method, nil + return r.RawRequest.Method, r.RawRequest.Method, nil default: - return "", fmt.Errorf("unsupported request key: %q", ha.Name) + return "", nil, fmt.Errorf("unsupported request key: %q", ha.Name) } case SourceEntirePayload: res, err := json.Marshal(&r.Payload) if err != nil { - return "", err + return "", r.Payload, err } - return string(res), nil + return string(res), r.Payload, nil case SourceEntireHeaders: res, err := json.Marshal(&r.Headers) if err != nil { - return "", err + return "", r.Headers, err } - return string(res), nil + return string(res), r.Headers, nil case SourceEntireQuery: res, err := json.Marshal(&r.Query) if err != nil { - return "", err + return "", r.Query, err } - return string(res), nil + return string(res), r.Query, nil } if source != nil { return ExtractParameterAsString(key, *source) } - return "", errors.New("no source for value retrieval") + return "", nil, errors.New("no source for value retrieval") } // Header is a structure containing header name and it's value @@ -589,7 +589,7 @@ func (h *Hook) ParseJSONParameters(r *Request) []error { errors := make([]error, 0) for i := range h.JSONStringParameters { - arg, err := h.JSONStringParameters[i].Get(r) + arg, _, err := h.JSONStringParameters[i].Get(r) if err != nil { errors = append(errors, &ArgumentError{h.JSONStringParameters[i]}) } else { @@ -645,7 +645,7 @@ func (h *Hook) ExtractCommandArguments(r *Request) ([]string, []error) { args = append(args, h.ExecuteCommand) for i := range h.PassArgumentsToCommand { - arg, err := h.PassArgumentsToCommand[i].Get(r) + arg, _, err := h.PassArgumentsToCommand[i].Get(r) if err != nil { args = append(args, "") errors = append(errors, &ArgumentError{h.PassArgumentsToCommand[i]}) @@ -669,7 +669,7 @@ func (h *Hook) ExtractCommandArgumentsForEnv(r *Request) ([]string, []error) { args := make([]string, 0) errors := make([]error, 0) for i := range h.PassEnvironmentToCommand { - arg, err := h.PassEnvironmentToCommand[i].Get(r) + arg, _, err := h.PassEnvironmentToCommand[i].Get(r) if err != nil { errors = append(errors, &ArgumentError{h.PassEnvironmentToCommand[i]}) continue @@ -705,7 +705,7 @@ func (h *Hook) ExtractCommandArgumentsForFile(r *Request) ([]FileParameter, []er args := make([]FileParameter, 0) errors := make([]error, 0) for i := range h.PassFileToCommand { - arg, err := h.PassFileToCommand[i].Get(r) + arg, _, err := h.PassFileToCommand[i].Get(r) if err != nil { errors = append(errors, &ArgumentError{h.PassFileToCommand[i]}) continue @@ -898,6 +898,8 @@ type MatchRule struct { const ( MatchValue string = "value" MatchRegex string = "regex" + MatchIsNull string = "is-null" + MatchExists string = "exists" MatchHMACSHA1 string = "payload-hmac-sha1" MatchHMACSHA256 string = "payload-hmac-sha256" MatchHMACSHA512 string = "payload-hmac-sha512" @@ -917,13 +919,17 @@ func (r MatchRule) Evaluate(req *Request) (bool, error) { return CheckScalrSignature(req, r.Secret, true) } - arg, err := r.Parameter.Get(req) + arg, rawValue, err := r.Parameter.Get(req) if err == nil { switch r.Type { case MatchValue: return compare(arg, r.Value), nil case MatchRegex: return regexp.MatchString(r.Regex, arg) + case MatchIsNull: + return rawValue == nil, nil + case MatchExists: + return true, nil case MatchHashSHA1: log.Print(`warn: use of deprecated option payload-hash-sha1; use payload-hmac-sha1 instead`) fallthrough @@ -944,6 +950,7 @@ func (r MatchRule) Evaluate(req *Request) (bool, error) { return err == nil, err } } + return false, err } diff --git a/internal/hook/hook_test.go b/internal/hook/hook_test.go index e9f29e4..e2ca805 100644 --- a/internal/hook/hook_test.go +++ b/internal/hook/hook_test.go @@ -245,7 +245,7 @@ var extractParameterTests = []struct { func TestExtractParameter(t *testing.T) { for _, tt := range extractParameterTests { - value, err := ExtractParameterAsString(tt.s, tt.params) + value, _, err := ExtractParameterAsString(tt.s, tt.params) if (err == nil) != tt.ok || value != tt.value { t.Errorf("failed to extract parameter %q:\nexpected {value:%#v, ok:%#v},\ngot {value:%#v, err:%v}", tt.s, tt.value, tt.ok, value, err) } @@ -281,7 +281,7 @@ func TestArgumentGet(t *testing.T) { Payload: tt.payload, RawRequest: tt.request, } - value, err := a.Get(r) + value, _, err := a.Get(r) if (err == nil) != tt.ok || value != tt.value { t.Errorf("failed to get {%q, %q}:\nexpected {value:%#v, ok:%#v},\ngot {value:%#v, err:%v}", tt.source, tt.name, tt.value, tt.ok, value, err) }