diff --git a/server/topic.go b/server/topic.go index e093a61..476f28a 100644 --- a/server/topic.go +++ b/server/topic.go @@ -1,11 +1,12 @@ package server import ( - "heckel.io/ntfy/log" - "heckel.io/ntfy/util" "math/rand" "sync" "time" + + "heckel.io/ntfy/log" + "heckel.io/ntfy/util" ) const ( @@ -45,9 +46,23 @@ func newTopic(id string) *topic { // Subscribe subscribes to this topic func (t *topic) Subscribe(s subscriber, userID string, cancel func()) int { + max_retries := 5 + retries := 1 t.mu.Lock() defer t.mu.Unlock() + subscriberID := rand.Int() + // simple check for existing id in maps + for { + _, ok := t.subscribers[subscriberID] + if ok && retries <= max_retries { + subscriberID = rand.Int() + retries++ + } else { + break + } + } + t.subscribers[subscriberID] = &topicSubscriber{ userID: userID, // May be empty subscriber: s, diff --git a/server/topic_test.go b/server/topic_test.go index b22bad5..400f4b6 100644 --- a/server/topic_test.go +++ b/server/topic_test.go @@ -1,10 +1,12 @@ package server import ( - "github.com/stretchr/testify/require" + "math/rand" "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestTopic_CancelSubscribers(t *testing.T) { @@ -39,3 +41,30 @@ func TestTopic_Keepalive(t *testing.T) { require.True(t, to.LastAccess().Unix() >= time.Now().Unix()-2) require.True(t, to.LastAccess().Unix() <= time.Now().Unix()+2) } + +func TestTopic_Subscribe_duplicateID(t *testing.T) { + t.Parallel() + + to := newTopic("mytopic") + + // fix random seed to force same number generation + rand.Seed(1) + a := rand.Int() + to.subscribers[a] = &topicSubscriber{ + userID: "a", + subscriber: nil, + cancel: func() {}, + } + + subFn := func(v *visitor, msg *message) error { + return nil + } + + // force rand.Int to generate the same id once more + rand.Seed(1) + id := to.Subscribe(subFn, "b", func() {}) + res := to.subscribers[id] + + require.False(t, id == a) + require.True(t, res.userID == "b") +}