diff --git a/go.mod b/go.mod index fad88a4..816766c 100644 --- a/go.mod +++ b/go.mod @@ -30,6 +30,7 @@ require ( github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21 // indirect github.com/envoyproxy/go-control-plane v0.10.1 // indirect github.com/envoyproxy/protoc-gen-validate v0.6.2 // indirect + github.com/gabriel-vasile/mimetype v1.4.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.2 // indirect github.com/google/go-cmp v0.5.6 // indirect diff --git a/go.sum b/go.sum index 91718f4..ef752ff 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ github.com/envoyproxy/go-control-plane v0.10.1/go.mod h1:AY7fTTXNdv/aJ2O5jwpxAPO github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/envoyproxy/protoc-gen-validate v0.6.2 h1:JiO+kJTpmYGjEodY7O1Zk8oZcNz1+f30UtwtXoFUPzE= github.com/envoyproxy/protoc-gen-validate v0.6.2/go.mod h1:2t7qjJNvHPx8IjnBOzl9E9/baC+qXE/TeeyBRzgJDws= +github.com/gabriel-vasile/mimetype v1.4.0 h1:Cn9dkdYsMIu56tGho+fqzh7XmvY2YyGU0FnbhiOsEro= +github.com/gabriel-vasile/mimetype v1.4.0/go.mod h1:fA8fi6KUiG7MgQQ+mEWotXoEOvmxRtOJlERCzSmRvr8= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -323,6 +325,7 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20210503060351-7fd8e65b6420/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210505024714-0287a6fb4125/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d h1:LO7XpTYMwTqxjLcGWPijK3vRXg1aWdlNOVOHRq45d7c= golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= diff --git a/server/server.go b/server/server.go index 82eddae..abf7a37 100644 --- a/server/server.go +++ b/server/server.go @@ -150,9 +150,10 @@ var ( ) const ( - firebaseControlTopic = "~control" // See Android if changed - emptyMessageBody = "triggered" - fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details + firebaseControlTopic = "~control" // See Android if changed + emptyMessageBody = "triggered" + fcmMessageLimit = 4000 // see maybeTruncateFCMMessage for details + defaultAttachmentMessage = "You received a file: %s" ) // New instantiates a new Server. It creates the cache and adds a Firebase @@ -436,7 +437,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, _ *visitor) return err } defer f.Close() - _, err = io.Copy(util.NewContentTypeWriter(w), f) + _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) return err } @@ -609,6 +610,9 @@ func (s *Server) handleBodyAsMessage(m *message, body *util.PeakedReadCloser) er if len(body.PeakedBytes) > 0 { // Empty body should not override message (publish via GET!) m.Message = strings.TrimSpace(string(body.PeakedBytes)) // Truncates the message to the peak limit if required } + if m.Attachment != nil && m.Attachment.Name != "" && m.Message == "" { + m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) + } return nil } @@ -622,16 +626,16 @@ func (s *Server) handleBodyAsAttachment(v *visitor, m *message, body *util.Peake m.Attachment = &attachment{} } var err error + var ext string m.Attachment.Owner = v.ip // Important for attachment rate limiting m.Attachment.Expires = time.Now().Add(s.config.AttachmentExpiryDuration).Unix() - m.Attachment.Type = http.DetectContentType(body.PeakedBytes) - ext := util.ExtensionByType(m.Attachment.Type) + m.Attachment.Type, ext = util.DetectContentType(body.PeakedBytes, m.Attachment.Name) m.Attachment.URL = fmt.Sprintf("%s/file/%s%s", s.config.BaseURL, m.ID, ext) if m.Attachment.Name == "" { m.Attachment.Name = fmt.Sprintf("attachment%s", ext) } if m.Message == "" { - m.Message = fmt.Sprintf("You received a file: %s", m.Attachment.Name) + m.Message = fmt.Sprintf(defaultAttachmentMessage, m.Attachment.Name) } // TODO do not allowed delayed delivery for attachments visitorAttachmentsSize, err := s.cache.AttachmentsSize(v.ip) diff --git a/server/util.go b/server/util.go index b0f0817..d36e397 100644 --- a/server/util.go +++ b/server/util.go @@ -12,8 +12,8 @@ import ( ) const ( - peakAttachmentTimeout = 2500 * time.Millisecond - peakAttachmeantReadBytes = 128 + peakAttachmentTimeout = 2500 * time.Millisecond + peakAttachmentReadBytes = 128 ) func maybePeakAttachmentURL(m *message) error { @@ -47,20 +47,21 @@ func maybePeakAttachmentURLInternal(m *message, timeout time.Duration) error { if size, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64); err == nil { m.Attachment.Size = size } + buf := make([]byte, peakAttachmentReadBytes) + io.ReadFull(resp.Body, buf) // Best effort: We don't care about the error + mimeType, ext := util.DetectContentType(buf, m.Attachment.URL) m.Attachment.Type = resp.Header.Get("Content-Type") - if m.Attachment.Type == "" || m.Attachment.Type == "application/octet-stream" { - buf := make([]byte, peakAttachmeantReadBytes) - io.ReadFull(resp.Body, buf) // Best effort: We don't care about the error - m.Attachment.Type = http.DetectContentType(buf) + if m.Attachment.Type == "" { + m.Attachment.Type = mimeType } if m.Attachment.Name == "" { u, err := url.Parse(m.Attachment.URL) if err != nil { - m.Attachment.Name = fmt.Sprintf("attachment%s", util.ExtensionByType(m.Attachment.Type)) + m.Attachment.Name = fmt.Sprintf("attachment%s", ext) } else { m.Attachment.Name = path.Base(u.Path) if m.Attachment.Name == "." || m.Attachment.Name == "/" { - m.Attachment.Name = fmt.Sprintf("attachment%s", util.ExtensionByType(m.Attachment.Type)) + m.Attachment.Name = fmt.Sprintf("attachment%s", ext) } } } diff --git a/util/content_type_writer.go b/util/content_type_writer.go index fb3c43f..7174f02 100644 --- a/util/content_type_writer.go +++ b/util/content_type_writer.go @@ -11,13 +11,14 @@ import ( // It will always set a Content-Type based on http.DetectContentType, but will never send the "text/html" // content type. type ContentTypeWriter struct { - w http.ResponseWriter - sniffed bool + w http.ResponseWriter + filename string + sniffed bool } // NewContentTypeWriter creates a new ContentTypeWriter -func NewContentTypeWriter(w http.ResponseWriter) *ContentTypeWriter { - return &ContentTypeWriter{w, false} +func NewContentTypeWriter(w http.ResponseWriter, filename string) *ContentTypeWriter { + return &ContentTypeWriter{w, filename, false} } func (w *ContentTypeWriter) Write(p []byte) (n int, err error) { @@ -27,7 +28,7 @@ func (w *ContentTypeWriter) Write(p []byte) (n int, err error) { // Detect and set Content-Type header // Fix content types that we don't want to inline-render in the browser. In particular, // we don't want to render HTML in the browser for security reasons. - contentType := http.DetectContentType(p) + contentType, _ := DetectContentType(p, w.filename) if strings.HasPrefix(contentType, "text/html") { contentType = strings.ReplaceAll(contentType, "text/html", "text/plain") } else if contentType == "application/octet-stream" { diff --git a/util/content_type_writer_test.go b/util/content_type_writer_test.go index 08dd751..0fdf65c 100644 --- a/util/content_type_writer_test.go +++ b/util/content_type_writer_test.go @@ -9,14 +9,14 @@ import ( func TestSniffWriter_WriteHTML(t *testing.T) { rr := httptest.NewRecorder() - sw := NewContentTypeWriter(rr) + sw := NewContentTypeWriter(rr, "") sw.Write([]byte("")) require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) } func TestSniffWriter_WriteTwoWriteCalls(t *testing.T) { rr := httptest.NewRecorder() - sw := NewContentTypeWriter(rr) + sw := NewContentTypeWriter(rr, "") sw.Write([]byte{0x25, 0x50, 0x44, 0x46, 0x2d, 0x11, 0x22, 0x33}) sw.Write([]byte("")) require.Equal(t, "application/pdf", rr.Header().Get("Content-Type")) @@ -34,7 +34,7 @@ func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) { // This test shows how splitting the HTML into two Write() calls will still yield text/plain rr := httptest.NewRecorder() - sw := NewContentTypeWriter(rr) + sw := NewContentTypeWriter(rr, "") sw.Write([]byte("alert('hi')")) require.Equal(t, "text/plain; charset=utf-8", rr.Header().Get("Content-Type")) @@ -42,9 +42,16 @@ func TestSniffWriter_WriteHTMLSplitIntoTwoWrites(t *testing.T) { func TestSniffWriter_WriteUnknownMimeType(t *testing.T) { rr := httptest.NewRecorder() - sw := NewContentTypeWriter(rr) + sw := NewContentTypeWriter(rr, "") randomBytes := make([]byte, 199) rand.Read(randomBytes) sw.Write(randomBytes) require.Equal(t, "application/octet-stream", rr.Header().Get("Content-Type")) } + +func TestSniffWriter_WriteWithFilenameAPK(t *testing.T) { + rr := httptest.NewRecorder() + sw := NewContentTypeWriter(rr, "https://example.com/ntfy.apk") + sw.Write([]byte{0x50, 0x4B, 0x03, 0x04}) + require.Equal(t, "application/vnd.android.package-archive", rr.Header().Get("Content-Type")) +} diff --git a/util/util.go b/util/util.go index ae4315b..c6e7623 100644 --- a/util/util.go +++ b/util/util.go @@ -3,8 +3,8 @@ package util import ( "errors" "fmt" + "github.com/gabriel-vasile/mimetype" "math/rand" - "mime" "os" "regexp" "strconv" @@ -21,7 +21,6 @@ var ( random = rand.New(rand.NewSource(time.Now().UnixNano())) randomMutex = sync.Mutex{} sizeStrRegex = regexp.MustCompile(`(?i)^(\d+)([gmkb])?$`) - extRegex = regexp.MustCompile(`^\.[-_A-Za-z0-9]+$`) errInvalidPriority = errors.New("invalid priority") ) @@ -168,20 +167,18 @@ func ShortTopicURL(s string) string { return strings.TrimPrefix(strings.TrimPrefix(s, "https://"), "http://") } -// ExtensionByType is a wrapper around mime.ExtensionByType with a few sensible corrections -func ExtensionByType(contentType string) string { - switch contentType { - case "image/jpeg": - return ".jpg" - case "video/mp4": - return ".mp4" - default: - exts, err := mime.ExtensionsByType(contentType) - if err == nil && len(exts) > 0 && extRegex.MatchString(exts[0]) { - return exts[0] - } - return ".bin" +// DetectContentType probes the byte array b and returns mime type and file extension. +// The filename is only used to override certain special cases. +func DetectContentType(b []byte, filename string) (mimeType string, ext string) { + if strings.HasSuffix(strings.ToLower(filename), ".apk") { + return "application/vnd.android.package-archive", ".apk" } + m := mimetype.Detect(b) + mimeType, ext = m.String(), m.Extension() + if ext == "" { + ext = ".bin" + } + return } // ParseSize parses a size string like 2K or 2M into bytes. If no unit is found, e.g. 123, bytes is assumed.