From f4c54a16438e09617a35c0f33ea88326873eab8e Mon Sep 17 00:00:00 2001 From: binwiederhier Date: Sun, 29 Jan 2023 15:11:26 -0500 Subject: [PATCH] Associate file downloads with uploader --- server/message_cache.go | 160 +++++++++++++++++++++++----------------- server/server.go | 47 +++++++++--- server/server.yml | 3 + server/server_test.go | 9 +-- 4 files changed, 134 insertions(+), 85 deletions(-) diff --git a/server/message_cache.go b/server/message_cache.go index f2b977c..eda0c80 100644 --- a/server/message_cache.go +++ b/server/message_cache.go @@ -16,6 +16,7 @@ import ( var ( errUnexpectedMessageType = errors.New("unexpected message type") + errMessageNotFound = errors.New("message not found") ) // Messages cache @@ -60,7 +61,12 @@ const ( deleteMessageQuery = `DELETE FROM messages WHERE mid = ?` updateMessagesForTopicExpiryQuery = `UPDATE messages SET expires = ? WHERE topic = ?` selectRowIDFromMessageID = `SELECT id FROM messages WHERE mid = ?` // Do not include topic, see #336 and TestServer_PollSinceID_MultipleTopics - selectMessagesSinceTimeQuery = ` + selectMessagesByIDQuery = ` + SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding + FROM messages + WHERE mid = ? + ` + selectMessagesSinceTimeQuery = ` SELECT mid, time, expires, topic, message, title, priority, tags, click, icon, actions, attachment_name, attachment_type, attachment_size, attachment_expires, attachment_url, sender, user, encoding FROM messages WHERE topic = ? AND time >= ? AND published = 1 @@ -448,6 +454,18 @@ func (c *messageCache) MessagesExpired() ([]string, error) { return ids, nil } +func (c *messageCache) Message(id string) (*message, error) { + rows, err := c.db.Query(selectMessagesByIDQuery, id) + if err != nil { + return nil, err + } + if !rows.Next() { + return nil, errMessageNotFound + } + defer rows.Close() + return readMessage(rows) +} + func (c *messageCache) MarkPublished(m *message) error { _, err := c.db.Exec(updateMessagePublishedQuery, m.ID) return err @@ -600,75 +618,11 @@ func readMessages(rows *sql.Rows) ([]*message, error) { defer rows.Close() messages := make([]*message, 0) for rows.Next() { - var timestamp, expires, attachmentSize, attachmentExpires int64 - var priority int - var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string - err := rows.Scan( - &id, - ×tamp, - &expires, - &topic, - &msg, - &title, - &priority, - &tagsStr, - &click, - &icon, - &actionsStr, - &attachmentName, - &attachmentType, - &attachmentSize, - &attachmentExpires, - &attachmentURL, - &sender, - &user, - &encoding, - ) + m, err := readMessage(rows) if err != nil { return nil, err } - var tags []string - if tagsStr != "" { - tags = strings.Split(tagsStr, ",") - } - var actions []*action - if actionsStr != "" { - if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil { - return nil, err - } - } - senderIP, err := netip.ParseAddr(sender) - if err != nil { - senderIP = netip.Addr{} // if no IP stored in database, return invalid address - } - var att *attachment - if attachmentName != "" && attachmentURL != "" { - att = &attachment{ - Name: attachmentName, - Type: attachmentType, - Size: attachmentSize, - Expires: attachmentExpires, - URL: attachmentURL, - } - } - messages = append(messages, &message{ - ID: id, - Time: timestamp, - Expires: expires, - Event: messageEvent, - Topic: topic, - Message: msg, - Title: title, - Priority: priority, - Tags: tags, - Click: click, - Icon: icon, - Actions: actions, - Attachment: att, - Sender: senderIP, // Must parse assuming database must be correct - User: user, - Encoding: encoding, - }) + messages = append(messages, m) } if err := rows.Err(); err != nil { return nil, err @@ -676,6 +630,78 @@ func readMessages(rows *sql.Rows) ([]*message, error) { return messages, nil } +func readMessage(rows *sql.Rows) (*message, error) { + var timestamp, expires, attachmentSize, attachmentExpires int64 + var priority int + var id, topic, msg, title, tagsStr, click, icon, actionsStr, attachmentName, attachmentType, attachmentURL, sender, user, encoding string + err := rows.Scan( + &id, + ×tamp, + &expires, + &topic, + &msg, + &title, + &priority, + &tagsStr, + &click, + &icon, + &actionsStr, + &attachmentName, + &attachmentType, + &attachmentSize, + &attachmentExpires, + &attachmentURL, + &sender, + &user, + &encoding, + ) + if err != nil { + return nil, err + } + var tags []string + if tagsStr != "" { + tags = strings.Split(tagsStr, ",") + } + var actions []*action + if actionsStr != "" { + if err := json.Unmarshal([]byte(actionsStr), &actions); err != nil { + return nil, err + } + } + senderIP, err := netip.ParseAddr(sender) + if err != nil { + senderIP = netip.Addr{} // if no IP stored in database, return invalid address + } + var att *attachment + if attachmentName != "" && attachmentURL != "" { + att = &attachment{ + Name: attachmentName, + Type: attachmentType, + Size: attachmentSize, + Expires: attachmentExpires, + URL: attachmentURL, + } + } + return &message{ + ID: id, + Time: timestamp, + Expires: expires, + Event: messageEvent, + Topic: topic, + Message: msg, + Title: title, + Priority: priority, + Tags: tags, + Click: click, + Icon: icon, + Actions: actions, + Attachment: att, + Sender: senderIP, // Must parse assuming database must be correct + User: user, + Encoding: encoding, + }, nil +} + func (c *messageCache) Close() error { return c.db.Close() } diff --git a/server/server.go b/server/server.go index ddd7556..d4c1573 100644 --- a/server/server.go +++ b/server/server.go @@ -37,7 +37,8 @@ import ( /* - HIGH Rate limiting: Sensitive endpoints (account/login/change-password/...) -- HIGH Stripe payment methods +- HIGH Docs +- Large uploads for higher tiers (nginx config!) - MEDIUM: Test new token endpoints & never-expiring token - MEDIUM: Make sure account endpoints make sense for admins - MEDIUM: Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben) @@ -498,6 +499,9 @@ func (s *Server) handleDocs(w http.ResponseWriter, r *http.Request, _ *visitor) return nil } +// handleFile processes the download of attachment files. The method handles GET and HEAD requests against a file. +// Before streaming the file to a client, it locates uploader (m.Sender or m.User) in the message cache, so it +// can associate the download bandwidth with the uploader. func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) error { if s.config.AttachmentCacheDir == "" { return errHTTPInternalError @@ -512,23 +516,42 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor) if err != nil { return errHTTPNotFound } - if r.Method == http.MethodGet { - if !v.BandwidthAllowed(stat.Size()) { - return errHTTPTooManyRequestsLimitAttachmentBandwidth - } - } - w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests - if r.Method == http.MethodGet { - f, err := os.Open(file) + w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size())) + if r.Method == http.MethodHead { + return nil + } + // Find message in database, and associate bandwidth to the uploader user + // This is an easy way to + // - avoid abuse (e.g. 1 uploader, 1k downloaders) + // - and also uses the higher bandwidth limits of a paying user + m, err := s.messageCache.Message(messageID) + if err == errMessageNotFound { + return errHTTPNotFound + } else if err != nil { + return err + } + bandwidthVisitor := v + if s.userManager != nil && m.User != "" { + u, err := s.userManager.UserByID(m.User) if err != nil { return err } - defer f.Close() - _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) + bandwidthVisitor = s.visitor(v.IP(), u) + } else if m.Sender != netip.IPv4Unspecified() { + bandwidthVisitor = s.visitor(m.Sender, nil) + } + if !bandwidthVisitor.BandwidthAllowed(stat.Size()) { + return errHTTPTooManyRequestsLimitAttachmentBandwidth + } + // Actually send file + f, err := os.Open(file) + if err != nil { return err } - return nil + defer f.Close() + _, err = io.Copy(util.NewContentTypeWriter(w, r.URL.Path), f) + return err } func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error { diff --git a/server/server.yml b/server/server.yml index 76198f9..b607110 100644 --- a/server/server.yml +++ b/server/server.yml @@ -80,6 +80,8 @@ # - auth-file is the SQLite user/access database; it is created automatically if it doesn't already exist # - auth-default-access defines the default/fallback access if no access control entry is found; it can be # set to "read-write" (default), "read-only", "write-only" or "deny-all". +# - auth-startup-queries allows you to run commands when the database is initialized, e.g. to enable +# WAL mode. This is similar to cache-startup-queries. See above for details. # # Debian/RPM package users: # Use /var/lib/ntfy/user.db as user database to avoid permission issues. The package @@ -91,6 +93,7 @@ # # auth-file: # auth-default-access: "read-write" +# auth-startup-queries: # If set, the X-Forwarded-For header is used to determine the visitor IP address # instead of the remote address of the connection. diff --git a/server/server_test.go b/server/server_test.go index c3bb9cf..06fe482 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -1543,6 +1543,7 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) { content := util.RandomString(5000) // > 4096 c := newTestConfigWithAuthFile(t) + c.VisitorAttachmentDailyBandwidthLimit = 1000 // Much lower than tier bandwidth! s := newTestServer(t, c) // Create tier with certain limits @@ -1566,16 +1567,12 @@ func TestServer_PublishAttachmentWithTierBasedBandwidthLimit(t *testing.T) { msg := toMessage(t, rr.Body.String()) // Retrieve it (first time succeeds) - rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ - "Authorization": util.BasicAuth("phil", "phil"), - }) + rr = request(t, s, "GET", "/file/"+msg.ID, content, nil) // File downloads do not send auth headers!! require.Equal(t, 200, rr.Code) require.Equal(t, content, rr.Body.String()) // Retrieve it AGAIN (fails, due to bandwidth limit) - rr = request(t, s, "GET", "/file/"+msg.ID, content, map[string]string{ - "Authorization": util.BasicAuth("phil", "phil"), - }) + rr = request(t, s, "GET", "/file/"+msg.ID, content, nil) require.Equal(t, 429, rr.Code) }