From db1a1fec0c5dbae8376d0f6b5047200578fce869 Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Fri, 17 Feb 2023 09:07:57 -0500 Subject: [PATCH] Custom HTTP response writer --- server/server.go | 1 + server/util.go | 57 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/server/server.go b/server/server.go index 4588614..b9e7b17 100644 --- a/server/server.go +++ b/server/server.go @@ -291,6 +291,7 @@ func (s *Server) closeDatabases() { // handle is the main entry point for all HTTP requests func (s *Server) handle(w http.ResponseWriter, r *http.Request) { + w = newHTTPResponseWriter(w) // Avoid logging "superfluous response.WriteHeader call" warning v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned if err != nil { s.handleError(w, r, v, err) diff --git a/server/util.go b/server/util.go index 1141e5d..99c5e1b 100644 --- a/server/util.go +++ b/server/util.go @@ -1,11 +1,14 @@ package server import ( + "bufio" "heckel.io/ntfy/util" "io" + "net" "net/http" "net/netip" "strings" + "sync" ) func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool { @@ -85,3 +88,57 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T, } return obj, nil } + +type httpResponseWriter struct { + w http.ResponseWriter + headerWritten bool + mu sync.Mutex +} + +type httpResponseWriterWithHijacker struct { + httpResponseWriter +} + +var _ http.ResponseWriter = (*httpResponseWriter)(nil) +var _ http.Flusher = (*httpResponseWriter)(nil) +var _ http.Hijacker = (*httpResponseWriterWithHijacker)(nil) + +func newHTTPResponseWriter(w http.ResponseWriter) http.ResponseWriter { + if _, ok := w.(http.Hijacker); ok { + return &httpResponseWriterWithHijacker{httpResponseWriter: httpResponseWriter{w: w}} + } + return &httpResponseWriter{w: w} +} + +func (w *httpResponseWriter) Header() http.Header { + return w.w.Header() +} + +func (w *httpResponseWriter) Write(bytes []byte) (int, error) { + w.mu.Lock() + w.headerWritten = true + w.mu.Unlock() + return w.w.Write(bytes) +} + +func (w *httpResponseWriter) WriteHeader(statusCode int) { + w.mu.Lock() + if w.headerWritten { + w.mu.Unlock() + return + } + w.headerWritten = true + w.mu.Unlock() + w.w.WriteHeader(statusCode) +} + +func (w *httpResponseWriter) Flush() { + if f, ok := w.w.(http.Flusher); ok { + f.Flush() + } +} + +func (w *httpResponseWriterWithHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + h, _ := w.w.(http.Hijacker) + return h.Hijack() +}