Attachment behavior fix for Firefox

This commit is contained in:
Philipp Heckel 2022-04-03 12:39:52 -04:00
parent f98743dd9b
commit aba7e86cbc
13 changed files with 223 additions and 123 deletions

View file

@ -355,7 +355,7 @@ func (c *messageCache) Prune(olderThan time.Time) error {
return err
}
func (c *messageCache) AttachmentsSize(owner string) (int64, error) {
func (c *messageCache) AttachmentBytesUsed(owner string) (int64, error) {
rows, err := c.db.Query(selectAttachmentsSizeQuery, owner, time.Now().Unix())
if err != nil {
return 0, err

View file

@ -337,11 +337,11 @@ func testCacheAttachments(t *testing.T, c *messageCache) {
require.Equal(t, "https://ntfy.sh/file/aCaRURL.jpg", messages[1].Attachment.URL)
require.Equal(t, "1.2.3.4", messages[1].Attachment.Owner)
size, err := c.AttachmentsSize("1.2.3.4")
size, err := c.AttachmentBytesUsed("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(30000), size)
size, err = c.AttachmentsSize("5.6.7.8")
size, err = c.AttachmentBytesUsed("5.6.7.8")
require.Nil(t, err)
require.Equal(t, int64(0), size)

View file

@ -66,6 +66,7 @@ var (
publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}/(publish|send|trigger)$`)
webConfigPath = "/config.js"
userStatsPath = "/user/stats"
staticRegex = regexp.MustCompile(`^/static/.+`)
docsRegex = regexp.MustCompile(`^/docs(|/.*)$`)
fileRegex = regexp.MustCompile(`^/file/([-_A-Za-z0-9]{1,64})(?:\.[A-Za-z0-9]{1,16})?$`)
@ -269,6 +270,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleEmpty(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.handleWebConfig(w, r)
} else if r.Method == http.MethodGet && r.URL.Path == userStatsPath {
return s.handleUserStats(w, r, v)
} else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) {
return s.handleStatic(w, r)
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
@ -351,6 +354,19 @@ var config = {
return err
}
func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visitor) error {
stats, err := v.Stats()
if err != nil {
return err
}
w.Header().Set("Content-Type", "text/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
if err := json.NewEncoder(w).Encode(stats); err != nil {
return err
}
return nil
}
func (s *Server) handleStatic(w http.ResponseWriter, r *http.Request) error {
r.URL.Path = webSiteDir + r.URL.Path
util.Gzip(http.FileServer(http.FS(webFsCached))).ServeHTTP(w, r)
@ -395,8 +411,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
if err != nil {
return err
}
return errHTTPEntityTooLargeAttachmentTooLarge
body, err := util.Peak(r.Body, s.config.MessageLimit)
body, err := util.Peek(r.Body, s.config.MessageLimit)
if err != nil {
return err
}
@ -540,35 +555,35 @@ func (s *Server) parsePublishParams(r *http.Request, v *visitor, m *message) (ca
// If file.txt is <= 4096 (message limit) and valid UTF-8, treat it as a message
// 5. curl -T file.txt ntfy.sh/mytopic
// If file.txt is > message limit, treat it as an attachment
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser, unifiedpush bool) error {
func (s *Server) handlePublishBody(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser, unifiedpush bool) error {
if unifiedpush {
return s.handleBodyAsMessageAutoDetect(m, body) // Case 1
} else if m.Attachment != nil && m.Attachment.URL != "" {
return s.handleBodyAsTextMessage(m, body) // Case 2
} else if m.Attachment != nil && m.Attachment.Name != "" {
return s.handleBodyAsAttachment(r, v, m, body) // Case 3
} else if !body.LimitReached && utf8.Valid(body.PeakedBytes) {
} else if !body.LimitReached && utf8.Valid(body.PeekedBytes) {
return s.handleBodyAsTextMessage(m, body) // Case 4
}
return s.handleBodyAsAttachment(r, v, m, body) // Case 5
}
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeakedReadCloser) error {
if utf8.Valid(body.PeakedBytes) {
m.Message = string(body.PeakedBytes) // Do not trim
func (s *Server) handleBodyAsMessageAutoDetect(m *message, body *util.PeekedReadCloser) error {
if utf8.Valid(body.PeekedBytes) {
m.Message = string(body.PeekedBytes) // Do not trim
} else {
m.Message = base64.StdEncoding.EncodeToString(body.PeakedBytes)
m.Message = base64.StdEncoding.EncodeToString(body.PeekedBytes)
m.Encoding = encodingBase64
}
return nil
}
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser) error {
if !utf8.Valid(body.PeakedBytes) {
func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeekedReadCloser) error {
if !utf8.Valid(body.PeekedBytes) {
return errHTTPBadRequestMessageNotUTF8
}
if len(body.PeakedBytes) > 0 { // Empty body should not override message (publish via GET!)
m.Message = strings.TrimSpace(string(body.PeakedBytes)) // Truncates the message to the peak limit if required
if len(body.PeekedBytes) > 0 { // Empty body should not override message (publish via GET!)
m.Message = strings.TrimSpace(string(body.PeekedBytes)) // Truncates the message to the peek limit if required
}
if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
@ -576,21 +591,20 @@ func (s *Server) handleBodyAsTextMessage(m *message, body *util.PeakedReadCloser
return nil
}
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeakedReadCloser) error {
func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message, body *util.PeekedReadCloser) error {
if s.fileCache == nil || s.config.BaseURL == "" || s.config.AttachmentCacheDir == "" {
return errHTTPBadRequestAttachmentsDisallowed
} else if m.Time > time.Now().Add(s.config.AttachmentExpiryDuration).Unix() {
return errHTTPBadRequestAttachmentsExpiryBeforeDelivery
}
visitorAttachmentsSize, err := s.messageCache.AttachmentsSize(v.ip)
visitorStats, err := v.Stats()
if err != nil {
return err
}
remainingVisitorAttachmentSize := s.config.VisitorAttachmentTotalSizeLimit - visitorAttachmentsSize
contentLengthStr := r.Header.Get("Content-Length")
if contentLengthStr != "" { // Early "do-not-trust" check, hard limit see below
contentLength, err := strconv.ParseInt(contentLengthStr, 10, 64)
if err == nil && (contentLength > remainingVisitorAttachmentSize || contentLength > s.config.AttachmentFileSizeLimit) {
if err == nil && (contentLength > visitorStats.VisitorAttachmentBytesRemaining || contentLength > s.config.AttachmentFileSizeLimit) {
return errHTTPEntityTooLargeAttachmentTooLarge
}
}
@ -600,7 +614,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
var ext string
m.Attachment.Owner = v.ip // Important for attachment rate limiting
m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix()
m.Attachment.Type, ext = util.DetectContentType(body.PeakedBytes, m.Attachment.Name)
m.Attachment.Type, ext = util.DetectContentType(body.PeekedBytes, m.Attachment.Name)
m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext)
if m.Attachment.Name == "" {
m.Attachment.Name = fmt.Sprintf("attachment%s", ext)
@ -608,7 +622,7 @@ func (s *Server) handleBodyAsAttachment(r *http.Request, v *visitor, m *message,
if m.Message == "" {
m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name)
}
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(remainingVisitorAttachmentSize))
m.Attachment.Size, err = s.fileCache.Write(m.ID, body, v.BandwidthLimiter(), util.NewFixedLimiter(visitorStats.VisitorAttachmentBytesRemaining))
if err == util.ErrLimitReached {
return errHTTPEntityTooLargeAttachmentTooLarge
} else if err != nil {
@ -1097,11 +1111,11 @@ func (s *Server) limitRequests(next handleFunc) handleFunc {
}
}
// transformBodyJSON peaks the request body, reads the JSON, and converts it to headers
// transformBodyJSON peeks the request body, reads the JSON, and converts it to headers
// before passing it on to the next handler. This is meant to be used in combination with handlePublish.
func (s *Server) transformBodyJSON(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
body, err := util.Peak(r.Body, s.config.MessageLimit)
body, err := util.Peek(r.Body, s.config.MessageLimit)
if err != nil {
return err
}
@ -1217,7 +1231,7 @@ func (s *Server) visitor(r *http.Request) *visitor {
}
v, exists := s.visitors[ip]
if !exists {
s.visitors[ip] = newVisitor(s.config, ip)
s.visitors[ip] = newVisitor(s.config, s.messageCache, ip)
return s.visitors[ip]
}
v.Keepalive()

View file

@ -938,7 +938,7 @@ func TestServer_PublishAttachment(t *testing.T) {
require.Equal(t, content, response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.messageCache.AttachmentsSize("9.9.9.9") // See request()
size, err := s.messageCache.AttachmentBytesUsed("9.9.9.9") // See request()
require.Nil(t, err)
require.Equal(t, int64(5000), size)
}
@ -967,7 +967,7 @@ func TestServer_PublishAttachmentShortWithFilename(t *testing.T) {
require.Equal(t, content, response.Body.String())
// Slightly unrelated cross-test: make sure we add an owner for internal attachments
size, err := s.messageCache.AttachmentsSize("1.2.3.4")
size, err := s.messageCache.AttachmentBytesUsed("1.2.3.4")
require.Nil(t, err)
require.Equal(t, int64(21), size)
}
@ -987,7 +987,7 @@ func TestServer_PublishAttachmentExternalWithoutFilename(t *testing.T) {
require.Equal(t, "", msg.Attachment.Owner)
// Slightly unrelated cross-test: make sure we don't add an owner for external attachments
size, err := s.messageCache.AttachmentsSize("127.0.0.1")
size, err := s.messageCache.AttachmentBytesUsed("127.0.0.1")
require.Nil(t, err)
require.Equal(t, int64(0), size)
}

View file

@ -22,6 +22,7 @@ var (
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
type visitor struct {
config *Config
messageCache *messageCache
ip string
requests *rate.Limiter
emails *rate.Limiter
@ -31,9 +32,17 @@ type visitor struct {
mu sync.Mutex
}
func newVisitor(conf *Config, ip string) *visitor {
type visitorStats struct {
AttachmentFileSizeLimit int64 `json:"attachmentFileSizeLimit"`
VisitorAttachmentBytesTotal int64 `json:"visitorAttachmentBytesTotal"`
VisitorAttachmentBytesUsed int64 `json:"visitorAttachmentBytesUsed"`
VisitorAttachmentBytesRemaining int64 `json:"visitorAttachmentBytesRemaining"`
}
func newVisitor(conf *Config, messageCache *messageCache, ip string) *visitor {
return &visitor{
config: conf,
messageCache: messageCache,
ip: ip,
requests: rate.NewLimiter(rate.Every(conf.VisitorRequestLimitReplenish), conf.VisitorRequestLimitBurst),
emails: rate.NewLimiter(rate.Every(conf.VisitorEmailLimitReplenish), conf.VisitorEmailLimitBurst),
@ -91,3 +100,20 @@ func (v *visitor) Stale() bool {
defer v.mu.Unlock()
return time.Since(v.seen) > visitorExpungeAfter
}
func (v *visitor) Stats() (*visitorStats, error) {
attachmentsBytesUsed, err := v.messageCache.AttachmentBytesUsed(v.ip)
if err != nil {
return nil, err
}
attachmentsBytesRemaining := v.config.VisitorAttachmentTotalSizeLimit - attachmentsBytesUsed
if attachmentsBytesRemaining < 0 {
attachmentsBytesRemaining = 0
}
return &visitorStats{
AttachmentFileSizeLimit: v.config.AttachmentFileSizeLimit,
VisitorAttachmentBytesTotal: v.config.VisitorAttachmentTotalSizeLimit,
VisitorAttachmentBytesUsed: attachmentsBytesUsed,
VisitorAttachmentBytesRemaining: attachmentsBytesRemaining,
}, nil
}