diff --git a/server/server.go b/server/server.go index 0b909b8..f7811c1 100644 --- a/server/server.go +++ b/server/server.go @@ -8,7 +8,6 @@ import ( "encoding/base64" "encoding/json" "fmt" - "heckel.io/ntfy/log" "io" "net" "net/http" @@ -23,6 +22,8 @@ import ( "time" "unicode/utf8" + "heckel.io/ntfy/log" + "github.com/emersion/go-smtp" "github.com/gorilla/websocket" "golang.org/x/sync/errgroup" @@ -289,7 +290,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit return s.ensureWebEnabled(s.handleStatic)(w, r, v) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { return s.ensureWebEnabled(s.handleDocs)(w, r, v) - } else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" { + } else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" { return s.limitRequests(s.handleFile)(w, r, v) } else if r.Method == http.MethodOptions { return s.ensureWebEnabled(s.handleOptions)(w, r, v) @@ -405,18 +406,23 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) if err != nil { return errHTTPNotFound } - if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil { - return errHTTPTooManyRequestsAttachmentBandwidthLimit + if r.Method == http.MethodGet { + if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil { + return errHTTPTooManyRequestsAttachmentBandwidthLimit + } } w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - f, err := os.Open(file) - if err != nil { + if r.Method == http.MethodGet { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) return err } - defer f.Close() - _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) - return err + return nil } func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { diff --git a/server/server_test.go b/server/server_test.go index 8010643..32f4fc2 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1026,12 +1026,19 @@ func TestServer_PublishAttachment(t *testing.T) { require.Equal(t, "", msg.Sender) // Should never be returned require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID)) + // GET path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345") response = request(t, s, "GET", path, "", nil) require.Equal(t, 200, response.Code) require.Equal(t, "5000", response.Header().Get("Content-Length")) require.Equal(t, content, response.Body.String()) + // HEAD + response = request(t, s, "HEAD", path, "", nil) + require.Equal(t, 200, response.Code) + require.Equal(t, "5000", response.Header().Get("Content-Length")) + require.Equal(t, "", response.Body.String()) + // Slightly unrelated cross-test: make sure we add an owner for internal attachments size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request() require.Nil(t, err)