Kill existing subscribers when topic is reserved
This commit is contained in:
parent
e82a2e518c
commit
bce71cb196
5 changed files with 169 additions and 36 deletions
|
@ -38,11 +38,13 @@ import (
|
|||
TODO
|
||||
--
|
||||
|
||||
- Reservation: Kill existing subscribers when topic is reserved (deadcade)
|
||||
- Rate limiting: Sensitive endpoints (account/login/change-password/...)
|
||||
- Reservation (UI): Show "This topic is reserved" error message when trying to reserve a reserved topic (Thorben)
|
||||
- Reservation (UI): Ask for confirmation when removing reservation (deadcade)
|
||||
- Reservation icons (UI)
|
||||
- reservation table delete button: dialog "keep or delete messages?"
|
||||
- UI: Flickering upgrade banner when logging in
|
||||
- JS constants
|
||||
|
||||
races:
|
||||
- v.user --> see publishSyncEventAsync() test
|
||||
|
@ -63,11 +65,6 @@ Limits & rate limiting:
|
|||
Make sure account endpoints make sense for admins
|
||||
|
||||
|
||||
UI:
|
||||
-
|
||||
- reservation table delete button: dialog "keep or delete messages?"
|
||||
- flicker of upgrade banner
|
||||
- JS constants
|
||||
Sync:
|
||||
- sync problems with "deleteAfter=0" and "displayName="
|
||||
|
||||
|
@ -359,7 +356,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
|||
log.Info("%s Connection closed with HTTP %d (ntfy error %d): %s", logHTTPPrefix(v, r), httpErr.HTTPCode, httpErr.Code, err.Error())
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.WriteHeader(httpErr.HTTPCode)
|
||||
io.WriteString(w, httpErr.JSON()+"\n")
|
||||
}
|
||||
|
@ -461,7 +458,7 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||
unifiedpush := readBoolParam(r, false, "x-unifiedpush", "unifiedpush", "up") // see PUT/POST too!
|
||||
if unifiedpush {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
_, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n")
|
||||
return err
|
||||
}
|
||||
|
@ -538,7 +535,7 @@ func (s *Server) handleFile(w http.ResponseWriter, r *http.Request, v *visitor)
|
|||
}
|
||||
}
|
||||
w.Header().Set("Content-Length", fmt.Sprintf("%d", stat.Size()))
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if r.Method == http.MethodGet {
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
|
@ -969,14 +966,16 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
}
|
||||
return nil
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
@ -991,6 +990,8 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case <-r.Context().Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
|
@ -1033,8 +1034,20 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Subscription connections can be canceled externally, see topic.CancelSubscribers
|
||||
subscriberContext, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
// Use errgroup to run WebSocket reader and writer in Go routines
|
||||
var wlock sync.Mutex
|
||||
g, ctx := errgroup.WithContext(context.Background())
|
||||
g, gctx := errgroup.WithContext(context.Background())
|
||||
g.Go(func() error {
|
||||
<-subscriberContext.Done()
|
||||
log.Trace("%s Cancel received, closing subscriber connection", logHTTPPrefix(v, r))
|
||||
conn.Close()
|
||||
return &websocket.CloseError{Code: websocket.CloseNormalClosure, Text: "subscription was canceled"}
|
||||
})
|
||||
g.Go(func() error {
|
||||
pongWait := s.config.KeepaliveInterval + wsPongWait
|
||||
conn.SetReadLimit(wsReadLimit)
|
||||
|
@ -1050,6 +1063,11 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-gctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
}
|
||||
})
|
||||
g.Go(func() error {
|
||||
|
@ -1064,7 +1082,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-gctx.Done():
|
||||
return nil
|
||||
case <-time.After(s.config.KeepaliveInterval):
|
||||
v.Keepalive()
|
||||
|
@ -1085,13 +1103,13 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
|||
}
|
||||
return conn.WriteJSON(msg)
|
||||
}
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||
if poll {
|
||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||
}
|
||||
subscriberIDs := make([]int, 0)
|
||||
for _, t := range topics {
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub))
|
||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v.MaybeUserID(), cancel))
|
||||
}
|
||||
defer func() {
|
||||
for i, subscriberID := range subscriberIDs {
|
||||
|
@ -1193,11 +1211,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
|
|||
if len(parts) < 2 {
|
||||
return nil, errHTTPBadRequestTopicInvalid
|
||||
}
|
||||
topics, err := s.topicsFromIDs(parts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics[0], nil
|
||||
return s.topicFromID(parts[1])
|
||||
}
|
||||
|
||||
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
|
||||
|
@ -1232,6 +1246,14 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
|||
return topics, nil
|
||||
}
|
||||
|
||||
func (s *Server) topicFromID(id string) (*topic, error) {
|
||||
topics, err := s.topicsFromIDs(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return topics[0], nil
|
||||
}
|
||||
|
||||
func (s *Server) execManager() {
|
||||
log.Debug("Manager: Starting")
|
||||
defer log.Debug("Manager: Finished")
|
||||
|
|
|
@ -2,7 +2,6 @@ package server
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"heckel.io/ntfy/log"
|
||||
"heckel.io/ntfy/user"
|
||||
"heckel.io/ntfy/util"
|
||||
|
@ -331,6 +330,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||
if v.user.Tier == nil {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
// CHeck if we are allowed to reserve this topic
|
||||
if err := s.userManager.CheckAllowAccess(v.user.Name, req.Topic); err != nil {
|
||||
return errHTTPConflictTopicReserved
|
||||
}
|
||||
|
@ -346,9 +346,16 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
|
|||
return errHTTPTooManyRequestsLimitReservations
|
||||
}
|
||||
}
|
||||
// Actually add the reservation
|
||||
if err := s.userManager.AddReservation(v.user.Name, req.Topic, everyone); err != nil {
|
||||
return err
|
||||
}
|
||||
// Kill existing subscribers
|
||||
t, err := s.topicFromID(req.Topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
t.CancelSubscribers(v.user.ID)
|
||||
return s.writeJSON(w, newSuccessResponse())
|
||||
}
|
||||
|
||||
|
@ -402,13 +409,10 @@ func (s *Server) publishSyncEvent(v *visitor) error {
|
|||
return nil
|
||||
}
|
||||
log.Trace("Publishing sync event to user %s's sync topic %s", v.user.Name, v.user.SyncTopic)
|
||||
topics, err := s.topicsFromIDs(v.user.SyncTopic)
|
||||
syncTopic, err := s.topicFromID(v.user.SyncTopic)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if len(topics) == 0 {
|
||||
return errors.New("cannot retrieve sync topic")
|
||||
}
|
||||
syncTopic := topics[0]
|
||||
messageBytes, err := json.Marshal(&apiAccountSyncTopicResponse{Event: syncTopicAccountSyncEvent})
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -496,3 +496,72 @@ func TestAccount_Reservation_PublishByAnonymousFails(t *testing.T) {
|
|||
rr = request(t, s, "POST", "/mytopic", `Howdy`, nil)
|
||||
require.Equal(t, 403, rr.Code)
|
||||
}
|
||||
|
||||
func TestAccount_Reservation_Add_Kills_Other_Subscribers(t *testing.T) {
|
||||
conf := newTestConfigWithAuthFile(t)
|
||||
conf.AuthDefault = user.PermissionReadWrite
|
||||
conf.EnableSignup = true
|
||||
s := newTestServer(t, conf)
|
||||
|
||||
// Create user with tier
|
||||
rr := request(t, s, "POST", "/v1/account", `{"username":"phil", "password":"mypass"}`, nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
require.Nil(t, s.userManager.CreateTier(&user.Tier{
|
||||
Code: "pro",
|
||||
MessagesLimit: 20,
|
||||
ReservationsLimit: 2,
|
||||
}))
|
||||
require.Nil(t, s.userManager.ChangeTier("phil", "pro"))
|
||||
|
||||
// Subscribe anonymously
|
||||
anonCh, userCh := make(chan bool), make(chan bool)
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
messages := toMessages(t, rr.Body.String())
|
||||
require.Equal(t, 2, len(messages)) // This is the meat. We should NOT receive the second message!
|
||||
require.Equal(t, "open", messages[0].Event)
|
||||
require.Equal(t, "message before reservation", messages[1].Message)
|
||||
anonCh <- true
|
||||
}()
|
||||
|
||||
// Subscribe with user
|
||||
go func() {
|
||||
rr := request(t, s, "GET", "/mytopic/json", ``, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
messages := toMessages(t, rr.Body.String())
|
||||
require.Equal(t, 3, len(messages))
|
||||
require.Equal(t, "open", messages[0].Event)
|
||||
require.Equal(t, "message before reservation", messages[1].Message)
|
||||
require.Equal(t, "message after reservation", messages[2].Message)
|
||||
userCh <- true
|
||||
}()
|
||||
|
||||
// Publish message (before reservation)
|
||||
time.Sleep(700 * time.Millisecond) // Wait for subscribers
|
||||
rr = request(t, s, "POST", "/mytopic", "message before reservation", nil)
|
||||
require.Equal(t, 200, rr.Code)
|
||||
time.Sleep(700 * time.Millisecond) // Wait for subscribers to receive message
|
||||
|
||||
// Reserve a topic
|
||||
rr = request(t, s, "POST", "/v1/account/reservation", `{"topic": "mytopic", "everyone":"deny-all"}`, map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Everyone but phil should be killed
|
||||
<-anonCh
|
||||
|
||||
// Publish a message
|
||||
rr = request(t, s, "POST", "/mytopic", "message after reservation", map[string]string{
|
||||
"Authorization": util.BasicAuth("phil", "mypass"),
|
||||
})
|
||||
require.Equal(t, 200, rr.Code)
|
||||
|
||||
// Kill user Go routine
|
||||
s.topics["mytopic"].CancelSubscribers("<invalid>")
|
||||
<-userCh
|
||||
}
|
||||
|
|
|
@ -10,10 +10,16 @@ import (
|
|||
// can publish a message
|
||||
type topic struct {
|
||||
ID string
|
||||
subscribers map[int]subscriber
|
||||
subscribers map[int]*topicSubscriber
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type topicSubscriber struct {
|
||||
userID string // User ID associated with this subscription, may be empty
|
||||
subscriber subscriber
|
||||
cancel func()
|
||||
}
|
||||
|
||||
// subscriber is a function that is called for every new message on a topic
|
||||
type subscriber func(v *visitor, msg *message) error
|
||||
|
||||
|
@ -21,16 +27,20 @@ type subscriber func(v *visitor, msg *message) error
|
|||
func newTopic(id string) *topic {
|
||||
return &topic{
|
||||
ID: id,
|
||||
subscribers: make(map[int]subscriber),
|
||||
subscribers: make(map[int]*topicSubscriber),
|
||||
}
|
||||
}
|
||||
|
||||
// Subscribe subscribes to this topic
|
||||
func (t *topic) Subscribe(s subscriber) int {
|
||||
func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscriberID := rand.Int()
|
||||
t.subscribers[subscriberID] = s
|
||||
t.subscribers[subscriberID] = &topicSubscriber{
|
||||
userID: userID, // May be empty
|
||||
subscriber: s,
|
||||
cancel: cancel,
|
||||
}
|
||||
return subscriberID
|
||||
}
|
||||
|
||||
|
@ -56,7 +66,7 @@ func (t *topic) Publish(v *visitor, m *message) error {
|
|||
if err := s(v, m); err != nil {
|
||||
log.Warn("%s Error forwarding to subscriber", logMessagePrefix(v, m))
|
||||
}
|
||||
}(s)
|
||||
}(s.subscriber)
|
||||
}
|
||||
} else {
|
||||
log.Trace("%s No stream or WebSocket subscribers, not forwarding", logMessagePrefix(v, m))
|
||||
|
@ -72,13 +82,29 @@ func (t *topic) SubscribersCount() int {
|
|||
return len(t.subscribers)
|
||||
}
|
||||
|
||||
// subscribersCopy returns a shallow copy of the subscribers map
|
||||
func (t *topic) subscribersCopy() map[int]subscriber {
|
||||
// CancelSubscribers calls the cancel function for all subscribers, forcing
|
||||
func (t *topic) CancelSubscribers(exceptUserID string) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscribers := make(map[int]subscriber)
|
||||
for k, v := range t.subscribers {
|
||||
subscribers[k] = v
|
||||
for _, s := range t.subscribers {
|
||||
if s.userID != exceptUserID {
|
||||
log.Trace("Canceling subscriber %s", s.userID)
|
||||
s.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// subscribersCopy returns a shallow copy of the subscribers map
|
||||
func (t *topic) subscribersCopy() map[int]*topicSubscriber {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
subscribers := make(map[int]*topicSubscriber)
|
||||
for k, sub := range t.subscribers {
|
||||
subscribers[k] = &topicSubscriber{
|
||||
userID: sub.userID,
|
||||
subscriber: sub.subscriber,
|
||||
cancel: sub.cancel,
|
||||
}
|
||||
}
|
||||
return subscribers
|
||||
}
|
||||
|
|
|
@ -228,12 +228,24 @@ func (v *visitor) ResetStats() {
|
|||
}
|
||||
}
|
||||
|
||||
// SetUser sets the visitors user to the given value
|
||||
func (v *visitor) SetUser(u *user.User) {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
v.user = u
|
||||
}
|
||||
|
||||
// MaybeUserID returns the user ID of the visitor (if any). If this is an anonymous visitor,
|
||||
// an empty string is returned.
|
||||
func (v *visitor) MaybeUserID() string {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
if v.user != nil {
|
||||
return v.user.ID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (v *visitor) Limits() *visitorLimits {
|
||||
v.mu.Lock()
|
||||
defer v.mu.Unlock()
|
||||
|
|
Loading…
Reference in a new issue