Refactor cli HTTP methods behavior

The CLI HTTP methods option now sets the default allowed methods while
allowing an individual hook definition to override the default.
This commit is contained in:
Cameron Moore 2019-12-27 11:22:04 -06:00
parent e1249a9ddb
commit 157f468e0c

View file

@ -50,7 +50,7 @@ var (
maxMultipartMem = flag.Int64("max-multipart-mem", 1<<20, "maximum memory in bytes for parsing multipart form data before disk caching")
setGID = flag.Int("setgid", 0, "set group ID after opening listening port; must be used with setuid")
setUID = flag.Int("setuid", 0, "set user ID after opening listening port; must be used with setgid")
httpMethods = flag.String("http-methods", "", "globally restrict allowed HTTP methods; separate methods with comma")
httpMethods = flag.String("http-methods", "", `set default allowed HTTP methods (ie. "POST"); separate methods with comma`)
responseHeaders hook.ResponseHeaders
hooksFiles hook.HooksFiles
@ -198,24 +198,16 @@ func main() {
r.Use(middleware.Dumper(log.Writer()))
}
// Clean up input
*httpMethods = strings.ToUpper(strings.ReplaceAll(*httpMethods, " ", ""))
hooksURL := makeURL(hooksURLPrefix)
r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprint(w, "OK")
})
var allowedMethods []string
if *httpMethods == "" {
r.HandleFunc(hooksURL, hookHandler)
} else {
allowedMethods = strings.Split(*httpMethods, ",")
for i := range allowedMethods {
allowedMethods[i] = strings.TrimSpace(allowedMethods[i])
}
r.HandleFunc(hooksURL, hookHandler).Methods(allowedMethods...)
}
r.HandleFunc(hooksURL, hookHandler)
addr := fmt.Sprintf("%s:%d", *ip, *port)
@ -242,13 +234,9 @@ func main() {
// Serve HTTP
if !*secure {
if len(allowedMethods) == 0 {
log.Printf("serving hooks on http://%s%s", addr, hooksURL)
} else {
log.Printf("serving hooks on http://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", "))
}
log.Printf("serving hooks on http://%s%s", addr, hooksURL)
log.Print(svr.Serve(ln))
return
}
@ -261,11 +249,7 @@ func main() {
}
svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
if len(allowedMethods) == 0 {
log.Printf("serving hooks on https://%s%s", addr, hooksURL)
} else {
log.Printf("serving hooks on https://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", "))
}
log.Printf("serving hooks on https://%s%s", addr, hooksURL)
log.Print(svr.ServeTLS(ln, *cert, *key))
}
@ -284,24 +268,34 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
}
// Check for allowed methods
if len(matchedHook.HTTPMethods) != 0 {
var allowed bool
var allowedMethod bool
switch {
case len(matchedHook.HTTPMethods) != 0:
for i := range matchedHook.HTTPMethods {
// TODO(moorereason): refactor config loading and reloading to
// sanitize these methods once at load time.
matchedHook.HTTPMethods[i] = strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i]))
if matchedHook.HTTPMethods[i] == r.Method {
allowed = true
if r.Method == strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i])) {
allowedMethod = true
break
}
}
if !allowed {
w.WriteHeader(http.StatusMethodNotAllowed)
log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id)
return
case *httpMethods != "":
for _, v := range strings.Split(*httpMethods, ",") {
if r.Method == v {
allowedMethod = true
break
}
}
default:
allowedMethod = true
}
if !allowedMethod {
w.WriteHeader(http.StatusMethodNotAllowed)
log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id)
return
}
log.Printf("[%s] %s got matched\n", rid, id)