Kill existing subscribers when topic is reserved

This commit is contained in:
binwiederhier 2023-01-23 14:05:41 -05:00
parent e82a2e518c
commit bce71cb196
5 changed files with 169 additions and 36 deletions

View file

@ -38,11 +38,13 @@ import (
TODO
--
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
- Reservation icons (UI)
- reservation table delete button: dialog "keep or delete messages?"
- UI: Flickering upgrade banner when logging in
- JS constants
races:
- v.user --> see publishSyncEventAsync() test
@ -63,11 +65,6 @@ Limits & rate limiting:
Make sure account endpoints make sense for admins
UI:
-
- reservation table delete button: dialog "keep or delete messages?"
- flicker of upgrade banner
- JS constants
Sync:
- sync problems with "deleteAfter=0" and "displayName="
@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.WriteHeader(httpErr.HTTPCode)
io.WriteString(w, httpErr.JSON()+"\n")
}
@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor)
unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
if unifiedpush {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
return err
}
@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
}
}
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if r.Method == http.MethodGet {
f, err := os.Open(file)
if err != nil {
@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
}
return nil
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
}
for {
select {
case <-ctx.Done():
return nil
case <-r.Context().Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
return err
}
defer conn.Close()
// Subscription connections can be canceled externally, see topic.CancelSubscribers
subscriberContext, cancel := context.WithCancel(context.Background())
defer cancel()
// Use errgroup to run WebSocket reader and writer in Go routines
var wlock sync.Mutex
g, ctx := errgroup.WithContext(context.Background())
g, gctx := errgroup.WithContext(context.Background())
g.Go(func() error {
<-subscriberContext.Done()
log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
conn.Close()
return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
})
g.Go(func() error {
pongWait := s.config.KeepaliveInterval + wsPongWait
conn.SetReadLimit(wsReadLimit)
@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
if err != nil {
return err
}
select {
case <-gctx.Done():
return nil
default:
}
}
})
g.Go(func() error {
@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
for {
select {
case <-ctx.Done():
case <-gctx.Done():
return nil
case <-time.After(s.config.KeepaliveInterval):
v.Keepalive()
@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
return conn.WriteJSON(msg)
}
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub)
}
subscriberIDs := make([]int, 0)
for _, t := range topics {
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
}
defer func() {
for i, subscriberID := range subscriberIDs {
@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
if len(parts) < 2 {
return nil, errHTTPBadRequestTopicInvalid
}
topics, err := s.topicsFromIDs(parts[1])
if err != nil {
return nil, err
}
return topics[0], nil
return s.topicFromID(parts[1])
}
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return topics, nil
}
func (s *Server) topicFromID(id string) (*topic, error) {
topics, err := s.topicsFromIDs(id)
if err != nil {
return nil, err
}
return topics[0], nil
}
func (s *Server) execManager() {
log.Debug("Manager: Starting")
defer log.Debug("Manager: Finished")

View file

@ -2,7 +2,6 @@ package server
import (
"encoding/json"
"errors"
"heckel.io/ntfy/log"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
@ -331,6 +330,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
if v.user.Tier == nil {
return errHTTPUnauthorized
}
// CHeck if we are allowed to reserve this topic
if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
return errHTTPConflictTopicReserved
}
@ -346,9 +346,16 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
return errHTTPTooManyRequestsLimitReservations
}
}
// Actually add the reservation
if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
return err
}
// Kill existing subscribers
t, err := s.topicFromID(req.Topic)
if err != nil {
return err
}
t.CancelSubscribers(v.user.ID)
return s.writeJSON(w, newSuccessResponse())
}
@ -402,13 +409,10 @@ func (s *Server) publishSyncEvent(v *visitor) error {
return nil
}
log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
topics, err := s.topicsFromIDs(v.user.SyncTopic)
syncTopic, err := s.topicFromID(v.user.SyncTopic)
if err != nil {
return err
} else if len(topics) == 0 {
return errors.New("cannot retrieve sync topic")
}
syncTopic := topics[0]
messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent})
if err != nil {
return err

View file

@ -496,3 +496,72 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
rr = request(t, s, "POST", "/mytopic", `Howdy`, nil)
require.Equal(t, 403, rr.Code)
}
func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
conf := newTestConfigWithAuthFile(t)
conf.AuthDefault = user.PermissionReadWrite
conf.EnableSignup = true
s := newTestServer(t, conf)
// Create user with tier
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
require.Equal(t, 200, rr.Code)
require.Nil(t, s.userManager.CreateTier(&user.Tier{
Code: "pro",
MessagesLimit: 20,
ReservationsLimit: 2,
}))
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
// Subscribe anonymously
anonCh, userCh := make(chan bool), make(chan bool)
go func() {
rr := request(t, s, "GET", "/mytopic/json", ``, nil)
require.Equal(t, 200, rr.Code)
messages := toMessages(t, rr.Body.String())
require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
require.Equal(t, "open", messages[0].Event)
require.Equal(t, "message before reservation", messages[1].Message)
anonCh <- true
}()
// Subscribe with user
go func() {
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 200, rr.Code)
messages := toMessages(t, rr.Body.String())
require.Equal(t, 3, len(messages))
require.Equal(t, "open", messages[0].Event)
require.Equal(t, "message before reservation", messages[1].Message)
require.Equal(t, "message after reservation", messages[2].Message)
userCh <- true
}()
// Publish message (before reservation)
time.Sleep(700 * time.Millisecond) // Wait for subscribers
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
require.Equal(t, 200, rr.Code)
time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message
// Reserve a topic
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 200, rr.Code)
// Everyone but phil should be killed
<-anonCh
// Publish a message
rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
"Authorization": util.BasicAuth("phil", "mypass"),
})
require.Equal(t, 200, rr.Code)
// Kill user Go routine
s.topics["mytopic"].CancelSubscribers("<invalid>")
<-userCh
}

View file

@ -10,10 +10,16 @@ import (
// can publish a message
type topic struct {
ID string
subscribers map[int]subscriber
subscribers map[int]*topicSubscriber
mu sync.Mutex
}
type topicSubscriber struct {
userID string // User ID associated with this subscription, may be empty
subscriber subscriber
cancel func()
}
// subscriber is a function that is called for every new message on a topic
type subscriber func(v *visitor, msg *message) error
@ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error
func newTopic(id string) *topic {
return &topic{
ID: id,
subscribers: make(map[int]subscriber),
subscribers: make(map[int]*topicSubscriber),
}
}
// Subscribe subscribes to this topic
func (t *topic) Subscribe(s subscriber) int {
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
t.mu.Lock()
defer t.mu.Unlock()
subscriberID := rand.Int()
t.subscribers[subscriberID] = s
t.subscribers[subscriberID] = &topicSubscriber{
userID: userID, // May be empty
subscriber: s,
cancel: cancel,
}
return subscriberID
}
@ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error {
if err := s(v, m); err != nil {
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
}
}(s)
}(s.subscriber)
}
} else {
log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
@ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int {
return len(t.subscribers)
}
// subscribersCopy returns a shallow copy of the subscribers map
func (t *topic) subscribersCopy() map[int]subscriber {
// CancelSubscribers calls the cancel function for all subscribers, forcing
func (t *topic) CancelSubscribers(exceptUserID string) {
t.mu.Lock()
defer t.mu.Unlock()
subscribers := make(map[int]subscriber)
for k, v := range t.subscribers {
subscribers[k] = v
for _, s := range t.subscribers {
if s.userID != exceptUserID {
log.Trace("Canceling subscriber %s", s.userID)
s.cancel()
}
}
}
// subscribersCopy returns a shallow copy of the subscribers map
func (t *topic) subscribersCopy() map[int]*topicSubscriber {
t.mu.Lock()
defer t.mu.Unlock()
subscribers := make(map[int]*topicSubscriber)
for k, sub := range t.subscribers {
subscribers[k] = &topicSubscriber{
userID: sub.userID,
subscriber: sub.subscriber,
cancel: sub.cancel,
}
}
return subscribers
}

View file

@ -228,12 +228,24 @@ func (v *visitor) ResetStats() {
}
}
// SetUser sets the visitors user to the given value
func (v *visitor) SetUser(u *user.User) {
v.mu.Lock()
defer v.mu.Unlock()
v.user = u
}
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
// an empty string is returned.
func (v *visitor) MaybeUserID() string {
v.mu.Lock()
defer v.mu.Unlock()
if v.user != nil {
return v.user.ID
}
return ""
}
func (v *visitor) Limits() *visitorLimits {
v.mu.Lock()
defer v.mu.Unlock()