Compare commits

...

1 commit

Author SHA1 Message Date
binwiederhier
db1a1fec0c Custom HTTP response writer 2023-02-17 09:07:57 -05:00
2 changed files with 58 additions and 0 deletions

View file

@ -291,6 +291,7 @@ func (s *Server) closeDatabases() {
// handle is the main entry point for all HTTP requests
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
w = newHTTPResponseWriter(w) // Avoid logging "superfluous response.WriteHeader call" warning
v, err := s.maybeAuthenticate(r) // Note: Always returns v, even when error is returned
if err != nil {
s.handleError(w, r, v, err)

View file

@ -1,11 +1,14 @@
package server
import (
"bufio"
"heckel.io/ntfy/util"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
)
func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
@ -85,3 +88,57 @@ func readJSONWithLimit[T any](r io.ReadCloser, limit int, allowEmpty bool) (*T,
}
return obj, nil
}
type httpResponseWriter struct {
w http.ResponseWriter
headerWritten bool
mu sync.Mutex
}
type httpResponseWriterWithHijacker struct {
httpResponseWriter
}
var _ http.ResponseWriter = (*httpResponseWriter)(nil)
var _ http.Flusher = (*httpResponseWriter)(nil)
var _ http.Hijacker = (*httpResponseWriterWithHijacker)(nil)
func newHTTPResponseWriter(w http.ResponseWriter) http.ResponseWriter {
if _, ok := w.(http.Hijacker); ok {
return &httpResponseWriterWithHijacker{httpResponseWriter: httpResponseWriter{w: w}}
}
return &httpResponseWriter{w: w}
}
func (w *httpResponseWriter) Header() http.Header {
return w.w.Header()
}
func (w *httpResponseWriter) Write(bytes []byte) (int, error) {
w.mu.Lock()
w.headerWritten = true
w.mu.Unlock()
return w.w.Write(bytes)
}
func (w *httpResponseWriter) WriteHeader(statusCode int) {
w.mu.Lock()
if w.headerWritten {
w.mu.Unlock()
return
}
w.headerWritten = true
w.mu.Unlock()
w.w.WriteHeader(statusCode)
}
func (w *httpResponseWriter) Flush() {
if f, ok := w.w.(http.Flusher); ok {
f.Flush()
}
}
func (w *httpResponseWriterWithHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
h, _ := w.w.(http.Hijacker)
return h.Hijack()
}