This commit is contained in:
binwiederhier 2023-03-03 22:22:07 -05:00
parent 3eeeac2c13
commit 346d8d7967
8 changed files with 141 additions and 15 deletions

View file

@ -585,9 +585,9 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
return writeMatrixDiscoveryResponse(w) return writeMatrixDiscoveryResponse(w)
} }
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) { func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
t := fromContext[topic](r, contextTopic) t := fromContext[*topic](r, contextTopic)
vrate := fromContext[visitor](r, contextRateVisitor) vrate := fromContext[*visitor](r, contextRateVisitor)
body, err := util.Peek(r.Body, s.config.MessageLimit) body, err := util.Peek(r.Body, s.config.MessageLimit)
if err != nil { if err != nil {
return nil, err return nil, err
@ -670,7 +670,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
} }
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
m, err := s.handlePublishWithoutResponse(r, v) m, err := s.handlePublishInternal(r, v)
if err != nil { if err != nil {
return err return err
} }
@ -678,10 +678,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
} }
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error { func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
_, err := s.handlePublishWithoutResponse(r, v) _, err := s.handlePublishInternal(r, v)
if err != nil { if err != nil {
if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode { if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode {
return writeMatrixResponse(w, e.rejectedPushKey) topic := fromContext[*topic](r, contextTopic)
pushKey := fromContext[string](r, contextMatrixPushKey)
if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter {
return writeMatrixResponse(w, pushKey)
}
} }
return err return err
} }
@ -1011,6 +1015,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
for _, t := range topics {
t.Keepalive()
}
return s.sendOldMessages(topics, since, scheduled, v, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
@ -1037,7 +1044,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
case <-r.Context().Done(): case <-r.Context().Done():
return nil return nil
case <-time.After(s.config.KeepaliveInterval): case <-time.After(s.config.KeepaliveInterval):
logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message") ev := logvr(v, r).Tag(tagSubscribe)
if len(topics) == 1 {
ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID)
} else {
ev.Trace("Sending keepalive message to %d topics", len(topics))
}
v.Keepalive() v.Keepalive()
for _, t := range topics { for _, t := range topics {
t.Keepalive() t.Keepalive()
@ -1154,6 +1166,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if poll { if poll {
for _, t := range topics {
t.Keepalive()
}
return s.sendOldMessages(topics, since, scheduled, v, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
} }
subscriberIDs := make([]int, 0) subscriberIDs := make([]int, 0)

View file

@ -2,8 +2,8 @@ package server
import ( import (
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"strings" "strings"
"time"
) )
func (s *Server) execManager() { func (s *Server) execManager() {
@ -39,13 +39,13 @@ func (s *Server) execManager() {
ev := log.Tag(tagManager).With(t) ev := log.Tag(tagManager).With(t)
if t.Stale() { if t.Stale() {
if ev.IsTrace() { if ev.IsTrace() {
ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, lastAccess.Format(time.RFC822)) ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, util.FormatTime(lastAccess))
} }
emptyTopics++ emptyTopics++
delete(s.topics, t.ID) delete(s.topics, t.ID)
} else { } else {
if ev.IsTrace() { if ev.IsTrace() {
ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, lastAccess.Format(time.RFC822)) ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, util.FormatTime(lastAccess))
} }
subscribers += subs subscribers += subs
} }

View file

@ -8,6 +8,7 @@ import (
"io" "io"
"net/http" "net/http"
"strings" "strings"
"time"
) )
// Matrix Push Gateway / UnifiedPush / ntfy integration: // Matrix Push Gateway / UnifiedPush / ntfy integration:
@ -71,6 +72,14 @@ type matrixResponse struct {
Rejected []string `json:"rejected"` Rejected []string `json:"rejected"`
} }
const (
// matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter is the time after which a Matrix response
// will return an HTTP 200 with the push key (i.e. "rejected":["<pushkey>"]}), if no rate visitor has been set on
// the topic. Rejecting the push key will instruct the Matrix server to invalidate the pushkey and stop sending
// messages to it. See https://spec.matrix.org/v1.6/push-gateway-api/
matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter = 12 * time.Hour
)
// errMatrixPushkeyRejected represents an error when handing Matrix gateway messages // errMatrixPushkeyRejected represents an error when handing Matrix gateway messages
// //
// If the push key is set, the app server will remove it and will never send messages using the same // If the push key is set, the app server will remove it and will never send messages using the same
@ -126,7 +135,9 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int)
if r.Header.Get("X-Forwarded-For") != "" { if r.Header.Get("X-Forwarded-For") != "" {
newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For")) newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
} }
newRequest.Header.Set("X-Matrix-Pushkey", pushKey) newRequest = withContext(newRequest, map[contextKey]any{
contextMatrixPushKey: pushKey,
})
return newRequest, nil return newRequest, nil
} }

View file

@ -11,6 +11,7 @@ type contextKey int
const ( const (
contextRateVisitor contextKey = iota + 2586 contextRateVisitor contextKey = iota + 2586
contextTopic contextTopic
contextMatrixPushKey
) )
func (s *Server) limitRequests(next handleFunc) handleFunc { func (s *Server) limitRequests(next handleFunc) handleFunc {

View file

@ -1172,6 +1172,56 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
require.Equal(t, 400, response.Code) require.Equal(t, 400, response.Code)
} }
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
t.Parallel()
s := newTestServer(t, newTestConfig(t))
subFn := func(v *visitor, msg *message) error {
return nil
}
// Publish and check last access
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
"Cache": "no",
})
require.Equal(t, 200, response.Code)
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
// Topic won't get pruned
s.execManager()
require.NotNil(t, s.topics["mytopic"])
// Fudge with last access, but subscribe, and see that it won't get pruned (because of subscriber)
subID := s.topics["mytopic"].Subscribe(subFn, "", func() {})
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
s.execManager()
require.NotNil(t, s.topics["mytopic"])
// It'll finally get pruned now that there are no subscribers and last access is 17 hours ago
s.topics["mytopic"].Unsubscribe(subID)
s.execManager()
require.Nil(t, s.topics["mytopic"])
}
func TestServer_TopicKeepaliveOnPoll(t *testing.T) {
t.Parallel()
s := newTestServer(t, newTestConfig(t))
// Create topic by polling once
response := request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
require.Equal(t, 200, response.Code)
// Mess with last access time
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
// Poll again and check keepalive time
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
require.Equal(t, 200, response.Code)
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
}
func TestServer_UnifiedPushDiscovery(t *testing.T) { func TestServer_UnifiedPushDiscovery(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
response := request(t, s, "GET", "/mytopic?up=1", "", nil) response := request(t, s, "GET", "/mytopic?up=1", "", nil)
@ -1301,6 +1351,32 @@ func TestServer_MatrixGateway_Push_Failure_NoSubscriber(t *testing.T) {
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code) require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
} }
func TestServer_MatrixGateway_Push_Failure_NoSubscriber_After13Hours(t *testing.T) {
c := newTestConfig(t)
c.VisitorSubscriberRateLimiting = true
s := newTestServer(t, c)
notification := `{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/mytopic?up=1"}]}}`
// No success if no rate visitor set (this also creates the topic in memory
response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
require.Equal(t, 507, response.Code)
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
require.Nil(t, s.topics["mytopic"].rateVisitor)
// Fake: This topic has been around for 13 hours without a rate visitor
s.topics["mytopic"].lastAccess = time.Now().Add(-13 * time.Hour)
// Same request should now return HTTP 200 with a rejected pushkey
response = request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
require.Equal(t, 200, response.Code)
require.Equal(t, `{"rejected":["http://127.0.0.1:12345/mytopic?up=1"]}`, strings.TrimSpace(response.Body.String()))
// Slightly unrelated: Test that topic is pruned after 16 hours
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
s.execManager()
require.Nil(t, s.topics["mytopic"])
}
func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) { func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) {
s := newTestServer(t, newTestConfig(t)) s := newTestServer(t, newTestConfig(t))
notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}` notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}`

View file

@ -2,13 +2,18 @@ package server
import ( import (
"heckel.io/ntfy/log" "heckel.io/ntfy/log"
"heckel.io/ntfy/util"
"math/rand" "math/rand"
"sync" "sync"
"time" "time"
) )
const ( const (
topicExpiryDuration = 6 * time.Hour // topicExpungeAfter defines how long a topic is active before it is removed from memory.
//
// This must be larger than matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter to give
// time for more requests to come in, so that we can send a {"rejected":["<pushkey>"]} response back.
topicExpungeAfter = 16 * time.Hour
) )
// topic represents a channel to which subscribers can subscribe, and publishers // topic represents a channel to which subscribers can subscribe, and publishers
@ -59,7 +64,13 @@ func (t *topic) Stale() bool {
if t.rateVisitor != nil && !t.rateVisitor.Stale() { if t.rateVisitor != nil && !t.rateVisitor.Stale() {
return false return false
} }
return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpungeAfter
}
func (t *topic) LastAccess() time.Time {
t.mu.RLock()
defer t.mu.RUnlock()
return t.lastAccess
} }
func (t *topic) SetRateVisitor(v *visitor) { func (t *topic) SetRateVisitor(v *visitor) {
@ -148,6 +159,7 @@ func (t *topic) Context() log.Context {
fields := map[string]any{ fields := map[string]any{
"topic": t.ID, "topic": t.ID,
"topic_subscribers": len(t.subscribers), "topic_subscribers": len(t.subscribers),
"topic_last_access": util.FormatTime(t.lastAccess),
} }
if t.rateVisitor != nil { if t.rateVisitor != nil {
for k, v := range t.rateVisitor.Context() { for k, v := range t.rateVisitor.Context() {

View file

@ -4,6 +4,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time"
) )
func TestTopic_CancelSubscribers(t *testing.T) { func TestTopic_CancelSubscribers(t *testing.T) {
@ -28,3 +29,13 @@ func TestTopic_CancelSubscribers(t *testing.T) {
require.True(t, canceled1.Load()) require.True(t, canceled1.Load())
require.False(t, canceled2.Load()) require.False(t, canceled2.Load())
} }
func TestTopic_Keepalive(t *testing.T) {
t.Parallel()
to := newTopic("mytopic")
to.lastAccess = time.Now().Add(-1 * time.Hour)
to.Keepalive()
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
}

View file

@ -107,8 +107,8 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
return r.WithContext(c) return r.WithContext(c)
} }
func fromContext[T any](r *http.Request, key contextKey) *T { func fromContext[T any](r *http.Request, key contextKey) T {
t, ok := r.Context().Value(key).(*T) t, ok := r.Context().Value(key).(T)
if !ok { if !ok {
panic(fmt.Sprintf("cannot find key %v in request context", key)) panic(fmt.Sprintf("cannot find key %v in request context", key))
} }