Allow HEAD requests for file attachments
This commit is contained in:
parent
10c89b2e55
commit
2b42cea1a3
2 changed files with 22 additions and 9 deletions
|
@ -8,7 +8,6 @@ import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -23,6 +22,8 @@ import (
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
|
||||||
"github.com/emersion/go-smtp"
|
"github.com/emersion/go-smtp"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
@ -289,7 +290,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
||||||
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
|
return s.ensureWebEnabled(s.handleStatic)(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||||
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
|
return s.ensureWebEnabled(s.handleDocs)(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
|
} else if (r.Method == http.MethodGet || r.Method == http.MethodHead) && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
|
||||||
return s.limitRequests(s.handleFile)(w, r, v)
|
return s.limitRequests(s.handleFile)(w, r, v)
|
||||||
} else if r.Method == http.MethodOptions {
|
} else if r.Method == http.MethodOptions {
|
||||||
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
|
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
|
||||||
|
@ -405,11 +406,14 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
|
if err := v.BandwidthLimiter().Allow(stat.Size()); err != nil {
|
||||||
return errHTTPTooManyRequestsAttachmentBandwidthLimit
|
return errHTTPTooManyRequestsAttachmentBandwidthLimit
|
||||||
}
|
}
|
||||||
|
}
|
||||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||||
|
if r.Method == http.MethodGet {
|
||||||
f, err := os.Open(file)
|
f, err := os.Open(file)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -417,6 +421,8 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
_, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f)
|
||||||
return err
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
|
|
|
@ -1026,12 +1026,19 @@ func TestServer_PublishAttachment(t *testing.T) {
|
||||||
require.Equal(t, "", msg.Sender) // Should never be returned
|
require.Equal(t, "", msg.Sender) // Should never be returned
|
||||||
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
require.FileExists(t, filepath.Join(s.config.AttachmentCacheDir, msg.ID))
|
||||||
|
|
||||||
|
// GET
|
||||||
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
path := strings.TrimPrefix(msg.Attachment.URL, "http://127.0.0.1:12345")
|
||||||
response = request(t, s, "GET", path, "", nil)
|
response = request(t, s, "GET", path, "", nil)
|
||||||
require.Equal(t, 200, response.Code)
|
require.Equal(t, 200, response.Code)
|
||||||
require.Equal(t, "5000", response.Header().Get("Content-Length"))
|
require.Equal(t, "5000", response.Header().Get("Content-Length"))
|
||||||
require.Equal(t, content, response.Body.String())
|
require.Equal(t, content, response.Body.String())
|
||||||
|
|
||||||
|
// HEAD
|
||||||
|
response = request(t, s, "HEAD", path, "", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.Equal(t, "5000", response.Header().Get("Content-Length"))
|
||||||
|
require.Equal(t, "", response.Body.String())
|
||||||
|
|
||||||
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
|
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
|
||||||
size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
|
size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
|
||||||
require.Nil(t, err)
|
require.Nil(t, err)
|
||||||
|
|
Loading…
Reference in a new issue