diff --git a/hook/hook.go b/hook/hook.go index 98fb975..935afc5 100644 --- a/hook/hook.go +++ b/hook/hook.go @@ -32,6 +32,7 @@ const ( SourceQuery string = "url" SourceQueryAlias string = "query" SourcePayload string = "payload" + SourceContext string = "context" SourceString string = "string" SourceEntirePayload string = "entire-payload" SourceEntireQuery string = "entire-query" @@ -323,7 +324,7 @@ type Argument struct { // Get Argument method returns the value for the Argument's key name // based on the Argument's source -func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string, bool) { +func (ha *Argument) Get(headers, query, payload *map[string]interface{}, context *map[string]interface{}) (string, bool) { var source *map[string]interface{} key := ha.Name @@ -335,6 +336,8 @@ func (ha *Argument) Get(headers, query, payload *map[string]interface{}) (string source = query case SourcePayload: source = payload + case SourceContext: + source = context case SourceString: return ha.Name, true case SourceEntirePayload: @@ -424,6 +427,7 @@ func (h *HooksFiles) Set(value string) error { type Hook struct { ID string `json:"id,omitempty"` ExecuteCommand string `json:"execute-command,omitempty"` + ContextProviderCommand string `json:"context-provider-command,omitempty"` CommandWorkingDirectory string `json:"command-working-directory,omitempty"` ResponseMessage string `json:"response-message,omitempty"` ResponseHeaders ResponseHeaders `json:"response-headers,omitempty"` @@ -441,11 +445,11 @@ type Hook struct { // ParseJSONParameters decodes specified arguments to JSON objects and replaces the // string with the newly created object -func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}) []error { +func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface{}, context *map[string]interface{}) []error { errors := make([]error, 0) for i := range h.JSONStringParameters { - if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload); ok { + if arg, ok := h.JSONStringParameters[i].Get(headers, query, payload, context); ok { var newArg map[string]interface{} decoder := json.NewDecoder(strings.NewReader(string(arg))) @@ -464,6 +468,8 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface source = headers case SourcePayload: source = payload + case SourceContext: + source = context case SourceQuery, SourceQueryAlias: source = query } @@ -493,14 +499,14 @@ func (h *Hook) ParseJSONParameters(headers, query, payload *map[string]interface // ExtractCommandArguments creates a list of arguments, based on the // PassArgumentsToCommand property that is ready to be used with exec.Command() -func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}) ([]string, []error) { +func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) { args := make([]string, 0) errors := make([]error, 0) args = append(args, h.ExecuteCommand) for i := range h.PassArgumentsToCommand { - if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassArgumentsToCommand[i].Get(headers, query, payload, context); ok { args = append(args, arg) } else { args = append(args, "") @@ -518,11 +524,11 @@ func (h *Hook) ExtractCommandArguments(headers, query, payload *map[string]inter // ExtractCommandArgumentsForEnv creates a list of arguments in key=value // format, based on the PassEnvironmentToCommand property that is ready to be used // with exec.Command(). -func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}) ([]string, []error) { +func (h *Hook) ExtractCommandArgumentsForEnv(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]string, []error) { args := make([]string, 0) errors := make([]error, 0) for i := range h.PassEnvironmentToCommand { - if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassEnvironmentToCommand[i].Get(headers, query, payload, context); ok { if h.PassEnvironmentToCommand[i].EnvName != "" { // first try to use the EnvName if specified args = append(args, h.PassEnvironmentToCommand[i].EnvName+"="+arg) @@ -552,11 +558,11 @@ type FileParameter struct { // ExtractCommandArgumentsForFile creates a list of arguments in key=value // format, based on the PassFileToCommand property that is ready to be used // with exec.Command(). -func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}) ([]FileParameter, []error) { +func (h *Hook) ExtractCommandArgumentsForFile(headers, query, payload *map[string]interface{}, context *map[string]interface{}) ([]FileParameter, []error) { args := make([]FileParameter, 0) errors := make([]error, 0) for i := range h.PassFileToCommand { - if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload); ok { + if arg, ok := h.PassFileToCommand[i].Get(headers, query, payload, context); ok { if h.PassFileToCommand[i].EnvName == "" { // if no environment-variable name is set, fall-back on the name @@ -664,16 +670,16 @@ type Rules struct { // Evaluate finds the first rule property that is not nil and returns the value // it evaluates to -func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { switch { case r.And != nil: - return r.And.Evaluate(headers, query, payload, body, remoteAddr) + return r.And.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Or != nil: - return r.Or.Evaluate(headers, query, payload, body, remoteAddr) + return r.Or.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Not != nil: - return r.Not.Evaluate(headers, query, payload, body, remoteAddr) + return r.Not.Evaluate(headers, query, payload, context, body, remoteAddr) case r.Match != nil: - return r.Match.Evaluate(headers, query, payload, body, remoteAddr) + return r.Match.Evaluate(headers, query, payload, context, body, remoteAddr) } return false, nil @@ -683,11 +689,11 @@ func (r Rules) Evaluate(headers, query, payload *map[string]interface{}, body *[ type AndRule []Rules // Evaluate AndRule will return true if and only if all of ChildRules evaluate to true -func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { res := true for _, v := range r { - rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) + rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr) if err != nil { return false, err } @@ -705,11 +711,11 @@ func (r AndRule) Evaluate(headers, query, payload *map[string]interface{}, body type OrRule []Rules // Evaluate OrRule will return true if any of ChildRules evaluate to true -func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { res := false for _, v := range r { - rv, err := v.Evaluate(headers, query, payload, body, remoteAddr) + rv, err := v.Evaluate(headers, query, payload, context, body, remoteAddr) if err != nil { return false, err } @@ -727,8 +733,8 @@ func (r OrRule) Evaluate(headers, query, payload *map[string]interface{}, body * type NotRule Rules // Evaluate NotRule will return true if and only if ChildRule evaluates to false -func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { - rv, err := Rules(r).Evaluate(headers, query, payload, body, remoteAddr) +func (r NotRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { + rv, err := Rules(r).Evaluate(headers, query, payload, context, body, remoteAddr) return !rv, err } @@ -753,15 +759,16 @@ const ( ) // Evaluate MatchRule will return based on the type -func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { +func (r MatchRule) Evaluate(headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte, remoteAddr string) (bool, error) { if r.Type == IPWhitelist { return CheckIPWhitelist(remoteAddr, r.IPRange) } + if r.Type == ScalrSignature { return CheckScalrSignature(*headers, *body, r.Secret, true) } - if arg, ok := r.Parameter.Get(headers, query, payload); ok { + if arg, ok := r.Parameter.Get(headers, query, payload, context); ok { switch r.Type { case MatchValue: return arg == r.Value, nil diff --git a/webhook.go b/webhook.go index 16cbcbe..f5a9cba 100644 --- a/webhook.go +++ b/webhook.go @@ -4,6 +4,7 @@ import ( "encoding/json" "flag" "fmt" + "io" "io/ioutil" "log" "net/http" @@ -205,7 +206,6 @@ func main() { } func hookHandler(w http.ResponseWriter, r *http.Request) { - // generate a request id for logging rid := uuid.NewV4().String()[:6] @@ -231,8 +231,87 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { // parse query variables query := valuesToMap(r.URL.Query()) - // parse body - var payload map[string]interface{} + // parse context + var context map[string]interface{} + + if matchedHook.ContextProviderCommand != "" { + // check the command exists + cmdPath, err := exec.LookPath(matchedHook.ContextProviderCommand) + if err != nil { + // give a last chance, maybe it's a relative path + relativeToCwd := filepath.Join(matchedHook.CommandWorkingDirectory, matchedHook.ContextProviderCommand) + // check the command exists + cmdPath, err = exec.LookPath(relativeToCwd) + } + + if err != nil { + log.Printf("[%s] unable to locate context provider command: '%s', %+v\n", rid, matchedHook.ContextProviderCommand, err) + // check if parameters specified in context-provider-command by mistake + if strings.IndexByte(matchedHook.ContextProviderCommand, ' ') != -1 { + s := strings.Fields(matchedHook.ContextProviderCommand)[0] + log.Printf("[%s] please use a wrapper script to provide arguments to context provider command for '%s'\n", rid, s) + } + } else { + contextProviderCommandStdin := struct { + HookID string `json:"hookID"` + Method string `json:"method"` + Body string `json:"body"` + RemoteAddress string `json:"remoteAddress"` + URI string `json:"URI"` + Host string `json:"host"` + Headers http.Header `json:"headers"` + Query url.Values `json:"query"` + }{ + HookID: matchedHook.ID, + Method: r.Method, + Body: string(body), + RemoteAddress: r.RemoteAddr, + URI: r.RequestURI, + Host: r.Host, + Headers: r.Header, + Query: r.URL.Query(), + } + + stdinJSON, err := json.Marshal(contextProviderCommandStdin) + + if err != nil { + log.Printf("[%s] unable to encode context as JSON string for the context provider command: %+v\n", rid, err) + } else { + cmd := exec.Command(cmdPath) + cmd.Dir = matchedHook.CommandWorkingDirectory + cmd.Env = append(os.Environ()) + stdin, err := cmd.StdinPipe() + + if err != nil { + log.Printf("[%s] unable to acquire stdin pipe for the context provider command: %+v\n", rid, err) + } else { + _, err := io.WriteString(stdin, string(stdinJSON)) + stdin.Close() + if err != nil { + log.Printf("[%s] unable to write to context provider command stdin: %+v\n", rid, err) + } else { + log.Printf("[%s] executing context provider command %s (%s) using %s as cwd\n", rid, matchedHook.ContextProviderCommand, cmd.Path, cmd.Dir) + out, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("[%s] unable to execute context provider command: %+v\n", rid, err) + } else { + log.Printf("[%s] got context provider command output: %+v\n", rid, string(out)) + + decoder := json.NewDecoder(strings.NewReader(string(out))) + decoder.UseNumber() + + err := decoder.Decode(&context) + + if err != nil { + log.Printf("[%s] unable to parse context provider command output: %+v\n", rid, err) + } + } + } + } + } + } + } // set contentType to IncomingPayloadContentType or header value contentType := r.Header.Get("Content-Type") @@ -240,6 +319,9 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { contentType = matchedHook.IncomingPayloadContentType } + // parse body + var payload map[string]interface{} + if strings.Contains(contentType, "json") { decoder := json.NewDecoder(strings.NewReader(string(body))) decoder.UseNumber() @@ -259,7 +341,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } // handle hook - errors := matchedHook.ParseJSONParameters(&headers, &query, &payload) + errors := matchedHook.ParseJSONParameters(&headers, &query, &payload, &context) for _, err := range errors { log.Printf("[%s] error parsing JSON parameters: %s\n", rid, err) } @@ -269,7 +351,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { if matchedHook.TriggerRule == nil { ok = true } else { - ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &body, r.RemoteAddr) + ok, err = matchedHook.TriggerRule.Evaluate(&headers, &query, &payload, &context, &body, r.RemoteAddr) if err != nil { msg := fmt.Sprintf("[%s] error evaluating hook: %s", rid, err) log.Print(msg) @@ -287,7 +369,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } if matchedHook.CaptureCommandOutput { - response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &body) + response, err := handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body) if err != nil { w.WriteHeader(http.StatusInternalServerError) @@ -305,7 +387,7 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, response) } } else { - go handleHook(matchedHook, rid, &headers, &query, &payload, &body) + go handleHook(matchedHook, rid, &headers, &query, &payload, &context, &body) // Check if a success return code is configured for the hook if matchedHook.SuccessHttpResponseCode != 0 { @@ -332,25 +414,25 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } } -func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, body *[]byte) (string, error) { +func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]interface{}, context *map[string]interface{}, body *[]byte) (string, error) { var errors []error // check the command exists cmdPath, err := exec.LookPath(h.ExecuteCommand) if err != nil { - // give a last chance, maybe is a relative path - relativeToCwd := filepath.Join(h.CommandWorkingDirectory, h.ExecuteCommand) + // give a last chance, maybe is a relative path + relativeToCwd := filepath.Join(h.CommandWorkingDirectory, h.ExecuteCommand) // check the command exists cmdPath, err = exec.LookPath(relativeToCwd) } if err != nil { - log.Printf("unable to locate command: '%s'", h.ExecuteCommand) + log.Printf("[%s] unable to locate command: '%s'\n", rid, h.ExecuteCommand) // check if parameters specified in execute-command by mistake if strings.IndexByte(h.ExecuteCommand, ' ') != -1 { s := strings.Fields(h.ExecuteCommand)[0] - log.Printf("use 'pass-arguments-to-command' to specify args for '%s'", s) + log.Printf("[%s] please use 'pass-arguments-to-command' to specify args for '%s'\n", rid, s) } return "", err @@ -359,19 +441,19 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in cmd := exec.Command(cmdPath) cmd.Dir = h.CommandWorkingDirectory - cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload) + cmd.Args, errors = h.ExtractCommandArguments(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments: %s\n", rid, err) } var envs []string - envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload) + envs, errors = h.ExtractCommandArgumentsForEnv(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments for environment: %s\n", rid, err) } - files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload) + files, errors := h.ExtractCommandArgumentsForFile(headers, query, payload, context) for _, err := range errors { log.Printf("[%s] error extracting command arguments for file: %s\n", rid, err) @@ -380,16 +462,16 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in for i := range files { tmpfile, err := ioutil.TempFile(h.CommandWorkingDirectory, files[i].EnvName) if err != nil { - log.Printf("[%s] error creating temp file [%s]", rid, err) + log.Printf("[%s] error creating temp file [%s]\n", rid, err) continue } log.Printf("[%s] writing env %s file %s", rid, files[i].EnvName, tmpfile.Name()) if _, err := tmpfile.Write(files[i].Data); err != nil { - log.Printf("[%s] error writing file %s [%s]", rid, tmpfile.Name(), err) + log.Printf("[%s] error writing file %s [%s]\n", rid, tmpfile.Name(), err) continue } if err := tmpfile.Close(); err != nil { - log.Printf("[%s] error closing file %s [%s]", rid, tmpfile.Name(), err) + log.Printf("[%s] error closing file %s [%s]\n", rid, tmpfile.Name(), err) continue } @@ -414,7 +496,7 @@ func handleHook(h *hook.Hook, rid string, headers, query, payload *map[string]in log.Printf("[%s] removing file %s\n", rid, files[i].File.Name()) err := os.Remove(files[i].File.Name()) if err != nil { - log.Printf("[%s] error removing file %s [%s]", rid, files[i].File.Name(), err) + log.Printf("[%s] error removing file %s [%s]\n", rid, files[i].File.Name(), err) } } }