From 157f468e0cf3ca17c23877eed2015ffdcb00c5fa Mon Sep 17 00:00:00 2001 From: Cameron Moore Date: Fri, 27 Dec 2019 11:22:04 -0600 Subject: [PATCH] Refactor cli HTTP methods behavior The CLI HTTP methods option now sets the default allowed methods while allowing an individual hook definition to override the default. --- webhook.go | 64 +++++++++++++++++++++++++----------------------------- 1 file changed, 29 insertions(+), 35 deletions(-) diff --git a/webhook.go b/webhook.go index 13b5595..5df2cfd 100644 --- a/webhook.go +++ b/webhook.go @@ -50,7 +50,7 @@ var ( maxMultipartMem = flag.Int64("max-multipart-mem", 1<<20, "maximum memory in bytes for parsing multipart form data before disk caching") setGID = flag.Int("setgid", 0, "set group ID after opening listening port; must be used with setuid") setUID = flag.Int("setuid", 0, "set user ID after opening listening port; must be used with setgid") - httpMethods = flag.String("http-methods", "", "globally restrict allowed HTTP methods; separate methods with comma") + httpMethods = flag.String("http-methods", "", `set default allowed HTTP methods (ie. "POST"); separate methods with comma`) responseHeaders hook.ResponseHeaders hooksFiles hook.HooksFiles @@ -198,24 +198,16 @@ func main() { r.Use(middleware.Dumper(log.Writer())) } + // Clean up input + *httpMethods = strings.ToUpper(strings.ReplaceAll(*httpMethods, " ", "")) + hooksURL := makeURL(hooksURLPrefix) r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { fmt.Fprint(w, "OK") }) - var allowedMethods []string - - if *httpMethods == "" { - r.HandleFunc(hooksURL, hookHandler) - } else { - allowedMethods = strings.Split(*httpMethods, ",") - for i := range allowedMethods { - allowedMethods[i] = strings.TrimSpace(allowedMethods[i]) - } - - r.HandleFunc(hooksURL, hookHandler).Methods(allowedMethods...) - } + r.HandleFunc(hooksURL, hookHandler) addr := fmt.Sprintf("%s:%d", *ip, *port) @@ -242,13 +234,9 @@ func main() { // Serve HTTP if !*secure { - if len(allowedMethods) == 0 { - log.Printf("serving hooks on http://%s%s", addr, hooksURL) - } else { - log.Printf("serving hooks on http://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", ")) - } - + log.Printf("serving hooks on http://%s%s", addr, hooksURL) log.Print(svr.Serve(ln)) + return } @@ -261,11 +249,7 @@ func main() { } svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2 - if len(allowedMethods) == 0 { - log.Printf("serving hooks on https://%s%s", addr, hooksURL) - } else { - log.Printf("serving hooks on https://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", ")) - } + log.Printf("serving hooks on https://%s%s", addr, hooksURL) log.Print(svr.ServeTLS(ln, *cert, *key)) } @@ -284,24 +268,34 @@ func hookHandler(w http.ResponseWriter, r *http.Request) { } // Check for allowed methods - if len(matchedHook.HTTPMethods) != 0 { - var allowed bool + var allowedMethod bool + + switch { + case len(matchedHook.HTTPMethods) != 0: for i := range matchedHook.HTTPMethods { // TODO(moorereason): refactor config loading and reloading to // sanitize these methods once at load time. - matchedHook.HTTPMethods[i] = strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i])) - - if matchedHook.HTTPMethods[i] == r.Method { - allowed = true + if r.Method == strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i])) { + allowedMethod = true break } } - - if !allowed { - w.WriteHeader(http.StatusMethodNotAllowed) - log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id) - return + case *httpMethods != "": + for _, v := range strings.Split(*httpMethods, ",") { + if r.Method == v { + allowedMethod = true + break + } } + default: + allowedMethod = true + } + + if !allowedMethod { + w.WriteHeader(http.StatusMethodNotAllowed) + log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id) + + return } log.Printf("[%s] %s got matched\n", rid, id)