Refactor cli HTTP methods behavior

The CLI HTTP methods option now sets the default allowed methods while
allowing an individual hook definition to override the default.
This commit is contained in:
Cameron Moore 2019-12-27 11:22:04 -06:00
parent e1249a9ddb
commit 157f468e0c

View file

@ -50,7 +50,7 @@ var (
maxMultipartMem = flag.Int64("max-multipart-mem", 1<<20, "maximum memory in bytes for parsing multipart form data before disk caching") maxMultipartMem = flag.Int64("max-multipart-mem", 1<<20, "maximum memory in bytes for parsing multipart form data before disk caching")
setGID = flag.Int("setgid", 0, "set group ID after opening listening port; must be used with setuid") setGID = flag.Int("setgid", 0, "set group ID after opening listening port; must be used with setuid")
setUID = flag.Int("setuid", 0, "set user ID after opening listening port; must be used with setgid") setUID = flag.Int("setuid", 0, "set user ID after opening listening port; must be used with setgid")
httpMethods = flag.String("http-methods", "", "globally restrict allowed HTTP methods; separate methods with comma") httpMethods = flag.String("http-methods", "", `set default allowed HTTP methods (ie. "POST"); separate methods with comma`)
responseHeaders hook.ResponseHeaders responseHeaders hook.ResponseHeaders
hooksFiles hook.HooksFiles hooksFiles hook.HooksFiles
@ -198,24 +198,16 @@ func main() {
r.Use(middleware.Dumper(log.Writer())) r.Use(middleware.Dumper(log.Writer()))
} }
// Clean up input
*httpMethods = strings.ToUpper(strings.ReplaceAll(*httpMethods, " ", ""))
hooksURL := makeURL(hooksURLPrefix) hooksURL := makeURL(hooksURLPrefix)
r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) { r.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprint(w, "OK") fmt.Fprint(w, "OK")
}) })
var allowedMethods []string
if *httpMethods == "" {
r.HandleFunc(hooksURL, hookHandler) r.HandleFunc(hooksURL, hookHandler)
} else {
allowedMethods = strings.Split(*httpMethods, ",")
for i := range allowedMethods {
allowedMethods[i] = strings.TrimSpace(allowedMethods[i])
}
r.HandleFunc(hooksURL, hookHandler).Methods(allowedMethods...)
}
addr := fmt.Sprintf("%s:%d", *ip, *port) addr := fmt.Sprintf("%s:%d", *ip, *port)
@ -242,13 +234,9 @@ func main() {
// Serve HTTP // Serve HTTP
if !*secure { if !*secure {
if len(allowedMethods) == 0 {
log.Printf("serving hooks on http://%s%s", addr, hooksURL) log.Printf("serving hooks on http://%s%s", addr, hooksURL)
} else {
log.Printf("serving hooks on http://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", "))
}
log.Print(svr.Serve(ln)) log.Print(svr.Serve(ln))
return return
} }
@ -261,11 +249,7 @@ func main() {
} }
svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2 svr.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler)) // disable http/2
if len(allowedMethods) == 0 {
log.Printf("serving hooks on https://%s%s", addr, hooksURL) log.Printf("serving hooks on https://%s%s", addr, hooksURL)
} else {
log.Printf("serving hooks on https://%s%s for %s", addr, hooksURL, strings.Join(allowedMethods, ", "))
}
log.Print(svr.ServeTLS(ln, *cert, *key)) log.Print(svr.ServeTLS(ln, *cert, *key))
} }
@ -284,25 +268,35 @@ func hookHandler(w http.ResponseWriter, r *http.Request) {
} }
// Check for allowed methods // Check for allowed methods
if len(matchedHook.HTTPMethods) != 0 { var allowedMethod bool
var allowed bool
switch {
case len(matchedHook.HTTPMethods) != 0:
for i := range matchedHook.HTTPMethods { for i := range matchedHook.HTTPMethods {
// TODO(moorereason): refactor config loading and reloading to // TODO(moorereason): refactor config loading and reloading to
// sanitize these methods once at load time. // sanitize these methods once at load time.
matchedHook.HTTPMethods[i] = strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i])) if r.Method == strings.ToUpper(strings.TrimSpace(matchedHook.HTTPMethods[i])) {
allowedMethod = true
if matchedHook.HTTPMethods[i] == r.Method {
allowed = true
break break
} }
} }
case *httpMethods != "":
for _, v := range strings.Split(*httpMethods, ",") {
if r.Method == v {
allowedMethod = true
break
}
}
default:
allowedMethod = true
}
if !allowed { if !allowedMethod {
w.WriteHeader(http.StatusMethodNotAllowed) w.WriteHeader(http.StatusMethodNotAllowed)
log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id) log.Printf("[%s] HTTP %s method not implemented for hook %q", rid, r.Method, id)
return return
} }
}
log.Printf("[%s] %s got matched\n", rid, id) log.Printf("[%s] %s got matched\n", rid, id)