Works
This commit is contained in:
parent
3eeeac2c13
commit
346d8d7967
8 changed files with 141 additions and 15 deletions
|
@ -585,9 +585,9 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
|
||||||
return writeMatrixDiscoveryResponse(w)
|
return writeMatrixDiscoveryResponse(w)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
|
func (s *Server) handlePublishInternal(r *http.Request, v *visitor) (*message, error) {
|
||||||
t := fromContext[topic](r, contextTopic)
|
t := fromContext[*topic](r, contextTopic)
|
||||||
vrate := fromContext[visitor](r, contextRateVisitor)
|
vrate := fromContext[*visitor](r, contextRateVisitor)
|
||||||
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
body, err := util.Peek(r.Body, s.config.MessageLimit)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -670,7 +670,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
m, err := s.handlePublishWithoutResponse(r, v)
|
m, err := s.handlePublishInternal(r, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -678,10 +678,14 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
func (s *Server) handlePublishMatrix(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
_, err := s.handlePublishWithoutResponse(r, v)
|
_, err := s.handlePublishInternal(r, v)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode {
|
if e, ok := err.(*errHTTP); ok && e.HTTPCode == errHTTPInsufficientStorageUnifiedPush.HTTPCode {
|
||||||
return writeMatrixResponse(w, e.rejectedPushKey)
|
topic := fromContext[*topic](r, contextTopic)
|
||||||
|
pushKey := fromContext[string](r, contextMatrixPushKey)
|
||||||
|
if time.Since(topic.LastAccess()) > matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter {
|
||||||
|
return writeMatrixResponse(w, pushKey)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -1011,6 +1015,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||||
if poll {
|
if poll {
|
||||||
|
for _, t := range topics {
|
||||||
|
t.Keepalive()
|
||||||
|
}
|
||||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
@ -1037,7 +1044,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
return nil
|
return nil
|
||||||
case <-time.After(s.config.KeepaliveInterval):
|
case <-time.After(s.config.KeepaliveInterval):
|
||||||
logvr(v, r).Tag(tagSubscribe).Trace("Sending keepalive message")
|
ev := logvr(v, r).Tag(tagSubscribe)
|
||||||
|
if len(topics) == 1 {
|
||||||
|
ev.With(topics[0]).Trace("Sending keepalive message to %s", topics[0].ID)
|
||||||
|
} else {
|
||||||
|
ev.Trace("Sending keepalive message to %d topics", len(topics))
|
||||||
|
}
|
||||||
v.Keepalive()
|
v.Keepalive()
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
t.Keepalive()
|
t.Keepalive()
|
||||||
|
@ -1154,6 +1166,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
if poll {
|
if poll {
|
||||||
|
for _, t := range topics {
|
||||||
|
t.Keepalive()
|
||||||
|
}
|
||||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
|
|
|
@ -2,8 +2,8 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) execManager() {
|
func (s *Server) execManager() {
|
||||||
|
@ -39,13 +39,13 @@ func (s *Server) execManager() {
|
||||||
ev := log.Tag(tagManager).With(t)
|
ev := log.Tag(tagManager).With(t)
|
||||||
if t.Stale() {
|
if t.Stale() {
|
||||||
if ev.IsTrace() {
|
if ev.IsTrace() {
|
||||||
ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, lastAccess.Format(time.RFC822))
|
ev.Trace("- topic %s: Deleting stale topic (%d subscribers, accessed %s)", t.ID, subs, util.FormatTime(lastAccess))
|
||||||
}
|
}
|
||||||
emptyTopics++
|
emptyTopics++
|
||||||
delete(s.topics, t.ID)
|
delete(s.topics, t.ID)
|
||||||
} else {
|
} else {
|
||||||
if ev.IsTrace() {
|
if ev.IsTrace() {
|
||||||
ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, lastAccess.Format(time.RFC822))
|
ev.Trace("- topic %s: %d subscribers, accessed %s", t.ID, subs, util.FormatTime(lastAccess))
|
||||||
}
|
}
|
||||||
subscribers += subs
|
subscribers += subs
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Matrix Push Gateway / UnifiedPush / ntfy integration:
|
// Matrix Push Gateway / UnifiedPush / ntfy integration:
|
||||||
|
@ -71,6 +72,14 @@ type matrixResponse struct {
|
||||||
Rejected []string `json:"rejected"`
|
Rejected []string `json:"rejected"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter is the time after which a Matrix response
|
||||||
|
// will return an HTTP 200 with the push key (i.e. "rejected":["<pushkey>"]}), if no rate visitor has been set on
|
||||||
|
// the topic. Rejecting the push key will instruct the Matrix server to invalidate the pushkey and stop sending
|
||||||
|
// messages to it. See https://spec.matrix.org/v1.6/push-gateway-api/
|
||||||
|
matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter = 12 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
// errMatrixPushkeyRejected represents an error when handing Matrix gateway messages
|
// errMatrixPushkeyRejected represents an error when handing Matrix gateway messages
|
||||||
//
|
//
|
||||||
// If the push key is set, the app server will remove it and will never send messages using the same
|
// If the push key is set, the app server will remove it and will never send messages using the same
|
||||||
|
@ -126,7 +135,9 @@ func newRequestFromMatrixJSON(r *http.Request, baseURL string, messageLimit int)
|
||||||
if r.Header.Get("X-Forwarded-For") != "" {
|
if r.Header.Get("X-Forwarded-For") != "" {
|
||||||
newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
|
newRequest.Header.Set("X-Forwarded-For", r.Header.Get("X-Forwarded-For"))
|
||||||
}
|
}
|
||||||
newRequest.Header.Set("X-Matrix-Pushkey", pushKey)
|
newRequest = withContext(newRequest, map[contextKey]any{
|
||||||
|
contextMatrixPushKey: pushKey,
|
||||||
|
})
|
||||||
return newRequest, nil
|
return newRequest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -11,6 +11,7 @@ type contextKey int
|
||||||
const (
|
const (
|
||||||
contextRateVisitor contextKey = iota + 2586
|
contextRateVisitor contextKey = iota + 2586
|
||||||
contextTopic
|
contextTopic
|
||||||
|
contextMatrixPushKey
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
|
|
|
@ -1172,6 +1172,56 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) {
|
||||||
require.Equal(t, 400, response.Code)
|
require.Equal(t, 400, response.Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_PublishAndExpungeTopicAfter16Hours(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := newTestServer(t, newTestConfig(t))
|
||||||
|
|
||||||
|
subFn := func(v *visitor, msg *message) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Publish and check last access
|
||||||
|
response := request(t, s, "POST", "/mytopic", "test", map[string]string{
|
||||||
|
"Cache": "no",
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
|
||||||
|
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
|
||||||
|
|
||||||
|
// Topic won't get pruned
|
||||||
|
s.execManager()
|
||||||
|
require.NotNil(t, s.topics["mytopic"])
|
||||||
|
|
||||||
|
// Fudge with last access, but subscribe, and see that it won't get pruned (because of subscriber)
|
||||||
|
subID := s.topics["mytopic"].Subscribe(subFn, "", func() {})
|
||||||
|
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||||
|
s.execManager()
|
||||||
|
require.NotNil(t, s.topics["mytopic"])
|
||||||
|
|
||||||
|
// It'll finally get pruned now that there are no subscribers and last access is 17 hours ago
|
||||||
|
s.topics["mytopic"].Unsubscribe(subID)
|
||||||
|
s.execManager()
|
||||||
|
require.Nil(t, s.topics["mytopic"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_TopicKeepaliveOnPoll(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
s := newTestServer(t, newTestConfig(t))
|
||||||
|
|
||||||
|
// Create topic by polling once
|
||||||
|
response := request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
|
||||||
|
// Mess with last access time
|
||||||
|
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||||
|
|
||||||
|
// Poll again and check keepalive time
|
||||||
|
response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.True(t, s.topics["mytopic"].lastAccess.Unix() >= time.Now().Unix()-2)
|
||||||
|
require.True(t, s.topics["mytopic"].lastAccess.Unix() <= time.Now().Unix()+2)
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_UnifiedPushDiscovery(t *testing.T) {
|
func TestServer_UnifiedPushDiscovery(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t))
|
||||||
response := request(t, s, "GET", "/mytopic?up=1", "", nil)
|
response := request(t, s, "GET", "/mytopic?up=1", "", nil)
|
||||||
|
@ -1301,6 +1351,32 @@ func TestServer_MatrixGateway_Push_Failure_NoSubscriber(t *testing.T) {
|
||||||
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
|
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_MatrixGateway_Push_Failure_NoSubscriber_After13Hours(t *testing.T) {
|
||||||
|
c := newTestConfig(t)
|
||||||
|
c.VisitorSubscriberRateLimiting = true
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
notification := `{"notification":{"devices":[{"pushkey":"http://127.0.0.1:12345/mytopic?up=1"}]}}`
|
||||||
|
|
||||||
|
// No success if no rate visitor set (this also creates the topic in memory
|
||||||
|
response := request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
|
||||||
|
require.Equal(t, 507, response.Code)
|
||||||
|
require.Equal(t, 50701, toHTTPError(t, response.Body.String()).Code)
|
||||||
|
require.Nil(t, s.topics["mytopic"].rateVisitor)
|
||||||
|
|
||||||
|
// Fake: This topic has been around for 13 hours without a rate visitor
|
||||||
|
s.topics["mytopic"].lastAccess = time.Now().Add(-13 * time.Hour)
|
||||||
|
|
||||||
|
// Same request should now return HTTP 200 with a rejected pushkey
|
||||||
|
response = request(t, s, "POST", "/_matrix/push/v1/notify", notification, nil)
|
||||||
|
require.Equal(t, 200, response.Code)
|
||||||
|
require.Equal(t, `{"rejected":["http://127.0.0.1:12345/mytopic?up=1"]}`, strings.TrimSpace(response.Body.String()))
|
||||||
|
|
||||||
|
// Slightly unrelated: Test that topic is pruned after 16 hours
|
||||||
|
s.topics["mytopic"].lastAccess = time.Now().Add(-17 * time.Hour)
|
||||||
|
s.execManager()
|
||||||
|
require.Nil(t, s.topics["mytopic"])
|
||||||
|
}
|
||||||
|
|
||||||
func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) {
|
func TestServer_MatrixGateway_Push_Failure_InvalidPushkey(t *testing.T) {
|
||||||
s := newTestServer(t, newTestConfig(t))
|
s := newTestServer(t, newTestConfig(t))
|
||||||
notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}`
|
notification := `{"notification":{"devices":[{"pushkey":"http://wrong-base-url.com/mytopic?up=1"}]}}`
|
||||||
|
|
|
@ -2,13 +2,18 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"heckel.io/ntfy/log"
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
topicExpiryDuration = 6 * time.Hour
|
// topicExpungeAfter defines how long a topic is active before it is removed from memory.
|
||||||
|
//
|
||||||
|
// This must be larger than matrixRejectPushKeyForUnifiedPushTopicWithoutRateVisitorAfter to give
|
||||||
|
// time for more requests to come in, so that we can send a {"rejected":["<pushkey>"]} response back.
|
||||||
|
topicExpungeAfter = 16 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||||
|
@ -59,7 +64,13 @@ func (t *topic) Stale() bool {
|
||||||
if t.rateVisitor != nil && !t.rateVisitor.Stale() {
|
if t.rateVisitor != nil && !t.rateVisitor.Stale() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpiryDuration
|
return len(t.subscribers) == 0 && time.Since(t.lastAccess) > topicExpungeAfter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *topic) LastAccess() time.Time {
|
||||||
|
t.mu.RLock()
|
||||||
|
defer t.mu.RUnlock()
|
||||||
|
return t.lastAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) SetRateVisitor(v *visitor) {
|
func (t *topic) SetRateVisitor(v *visitor) {
|
||||||
|
@ -148,6 +159,7 @@ func (t *topic) Context() log.Context {
|
||||||
fields := map[string]any{
|
fields := map[string]any{
|
||||||
"topic": t.ID,
|
"topic": t.ID,
|
||||||
"topic_subscribers": len(t.subscribers),
|
"topic_subscribers": len(t.subscribers),
|
||||||
|
"topic_last_access": util.FormatTime(t.lastAccess),
|
||||||
}
|
}
|
||||||
if t.rateVisitor != nil {
|
if t.rateVisitor != nil {
|
||||||
for k, v := range t.rateVisitor.Context() {
|
for k, v := range t.rateVisitor.Context() {
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTopic_CancelSubscribers(t *testing.T) {
|
func TestTopic_CancelSubscribers(t *testing.T) {
|
||||||
|
@ -28,3 +29,13 @@ func TestTopic_CancelSubscribers(t *testing.T) {
|
||||||
require.True(t, canceled1.Load())
|
require.True(t, canceled1.Load())
|
||||||
require.False(t, canceled2.Load())
|
require.False(t, canceled2.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTopic_Keepalive(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
to := newTopic("mytopic")
|
||||||
|
to.lastAccess = time.Now().Add(-1 * time.Hour)
|
||||||
|
to.Keepalive()
|
||||||
|
require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2)
|
||||||
|
require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2)
|
||||||
|
}
|
||||||
|
|
|
@ -107,8 +107,8 @@ func withContext(r *http.Request, ctx map[contextKey]any) *http.Request {
|
||||||
return r.WithContext(c)
|
return r.WithContext(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
func fromContext[T any](r *http.Request, key contextKey) *T {
|
func fromContext[T any](r *http.Request, key contextKey) T {
|
||||||
t, ok := r.Context().Value(key).(*T)
|
t, ok := r.Context().Value(key).(T)
|
||||||
if !ok {
|
if !ok {
|
||||||
panic(fmt.Sprintf("cannot find key %v in request context", key))
|
panic(fmt.Sprintf("cannot find key %v in request context", key))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue