diff --git a/server/server.go b/server/server.go
index 0b3381c..ec4ca67 100644
--- a/server/server.go
+++ b/server/server.go
@@ -571,7 +571,7 @@ func (s *Server) handleMatrixDiscovery(w http.ResponseWriter) error {
 }
 
 func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*message, error) {
-	vRate, ok := r.Context().Value("vRate").(*visitor)
+	vrate, ok := r.Context().Value("vRate").(*visitor)
 	if !ok {
 		return nil, errHTTPInternalError
 	}
@@ -579,8 +579,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 	if !ok {
 		return nil, errHTTPInternalError
 	}
-
-	if !vRate.MessageAllowed() {
+	if !vrate.MessageAllowed() {
 		return nil, errHTTPTooManyRequestsLimitMessages
 	}
 	body, err := util.Peek(r.Body, s.config.MessageLimit)
@@ -588,7 +587,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 		return nil, err
 	}
 	m := newDefaultMessage(t.ID, "")
-	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vRate, m)
+	cache, firebase, email, unifiedpush, err := s.parsePublishParams(r, vrate, m)
 	if err != nil {
 		return nil, err
 	}
@@ -607,7 +606,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 		m.Message = emptyMessageBody
 	}
 	delayed := m.Time > time.Now().Unix()
-	ev := logvrm(vRate, r, m).
+	ev := logvrm(vrate, r, m).
 		Tag(tagPublish).
 		Fields(log.Context{
 			"message_delayed":     delayed,
@@ -625,7 +624,7 @@ func (s *Server) handlePublishWithoutResponse(r *http.Request, v *visitor) (*mes
 			return nil, err
 		}
 		if s.firebaseClient != nil && firebase {
-			go s.sendToFirebase(vRate, m)
+			go s.sendToFirebase(vrate, m)
 		}
 		if s.smtpSender != nil && email != "" {
 			go s.sendEmail(v, m, email)
@@ -657,7 +656,6 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
 	if err != nil {
 		return err
 	}
-
 	return s.writeJSON(w, m)
 }
 
@@ -766,7 +764,7 @@ func (s *Server) parsePublishParams(r *http.Request, vRate *visitor, m *message)
 	if err != nil {
 		return false, false, "", false, errHTTPBadRequestPriorityInvalid
 	}
-	m.Tags = readCommaSeperatedParam(r, "x-tags", "tags", "tag", "ta")
+	m.Tags = readCommaSeparatedParam(r, "x-tags", "tags", "tag", "ta")
 	delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
 	if delayStr != "" {
 		if !cache {
@@ -986,6 +984,12 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 		}
 		return nil
 	}
+	for _, t := range topics {
+		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
+		if subscriberRateLimited {
+			t.SetRateVisitor(v)
+		}
+	}
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	w.Header().Set("Content-Type", contentType+"; charset=utf-8")                    // Android/Volley client needs charset!
 	if poll {
@@ -995,8 +999,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
 	defer cancel()
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -1122,14 +1125,19 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
 		}
 		return conn.WriteJSON(msg)
 	}
+	for _, t := range topics {
+		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
+		if subscriberRateLimited {
+			t.SetRateVisitor(v)
+		}
+	}
 	w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
 	if poll {
 		return s.sendOldMessages(topics, since, scheduled, v, sub)
 	}
 	subscriberIDs := make([]int, 0)
 	for _, t := range topics {
-		subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
-		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
+		subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel))
 	}
 	defer func() {
 		for i, subscriberID := range subscriberIDs {
@@ -1161,7 +1169,7 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
 	if err != nil {
 		return
 	}
-	subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
+	subscriberTopics = readCommaSeparatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
 	return
 }
 
diff --git a/server/server_middleware.go b/server/server_middleware.go
index 1d2f734..712223c 100644
--- a/server/server_middleware.go
+++ b/server/server_middleware.go
@@ -26,8 +26,8 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
 			return err
 		}
 		vrate := v
-		if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
-			vrate = topicCountsAgainst
+		if rateVisitor := t.RateVisitor(); rateVisitor != nil {
+			vrate = rateVisitor
 		}
 		r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
 
diff --git a/server/server_test.go b/server/server_test.go
index fe5a49f..7f2665a 100644
--- a/server/server_test.go
+++ b/server/server_test.go
@@ -1894,15 +1894,17 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
 	c.VisitorRequestLimitBurst = 3
 	s := newTestServer(t, c)
 
+	// "Register" visitor 1.2.3.4 to topic "subscriber1topic" as a rate limit visitor
 	subscriber1Fn := func(r *http.Request) {
 		r.RemoteAddr = "1.2.3.4"
 	}
 	rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
-		"Subscriber-Rate-Limit-Topics": "mytopic1",
+		"Subscriber-Rate-Limit-Topics": "subscriber1topic",
 	}, subscriber1Fn)
 	require.Equal(t, 200, rr.Code)
 	require.Equal(t, "", rr.Body.String())
 
+	// "Register" visitor 8.7.7.1 to topic "upSUB2topic" as a rate limit visitor (implicitly via topic name)
 	subscriber2Fn := func(r *http.Request) {
 		r.RemoteAddr = "8.7.7.1"
 	}
@@ -1910,20 +1912,28 @@ func TestServer_SubscriberRateLimiting(t *testing.T) {
 	require.Equal(t, 200, rr.Code)
 	require.Equal(t, "", rr.Body.String())
 
-	for i := 0; i < 3; i++ {
+	// Publish 2 messages to "subscriber1topic" as visitor 9.9.9.9. It'd be 3 normally, but the
+	// GET request before is also counted towards the request limiter.
+	for i := 0; i < 2; i++ {
 		rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
 		require.Equal(t, 200, rr.Code)
 	}
 	rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
 	require.Equal(t, 429, rr.Code)
 
-	for i := 0; i < 3; i++ {
+	// Publish another 2 messages to "upSUB2topic" as visitor 9.9.9.9
+	for i := 0; i < 2; i++ {
 		rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
 		require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
 	}
 	rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
 	require.Equal(t, 429, rr.Code)
 
+	// Hurray! At this point, visitor 9.9.9.9 has published 4 messages, even though
+	// VisitorRequestLimitBurst is 3. That means it's working.
+
+	// Now let's confirm that so far we haven't used up any of visitor 9.9.9.9's request limiter
+	// by publishing another 3 requests from it.
 	for i := 0; i < 3; i++ {
 		rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
 		require.Equal(t, 200, rr.Code)
@@ -1959,18 +1969,18 @@ func newTestServer(t *testing.T, config *Config) *Server {
 
 func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
 	rr := httptest.NewRecorder()
-	req, err := http.NewRequest(method, url, strings.NewReader(body))
+	r, err := http.NewRequest(method, url, strings.NewReader(body))
 	if err != nil {
 		t.Fatal(err)
 	}
-	req.RemoteAddr = "9.9.9.9" // Used for tests
+	r.RemoteAddr = "9.9.9.9" // Used for tests
 	for k, v := range headers {
-		req.Header.Set(k, v)
+		r.Header.Set(k, v)
 	}
 	for _, f := range fn {
-		f(req)
+		f(r)
 	}
-	s.handle(rr, req)
+	s.handle(rr, r)
 	return rr
 }
 
diff --git a/server/topic.go b/server/topic.go
index 85af005..3b0cb54 100644
--- a/server/topic.go
+++ b/server/topic.go
@@ -19,10 +19,9 @@ type topic struct {
 }
 
 type topicSubscriber struct {
-	subscriber          subscriber
-	visitor             *visitor // User ID associated with this subscription, may be empty
-	cancel              func()
-	subscriberRateLimit bool
+	subscriber subscriber
+	visitor    *visitor // User ID associated with this subscription, may be empty
+	cancel     func()
 }
 
 // subscriber is a function that is called for every new message on a topic
@@ -37,39 +36,40 @@ func newTopic(id string) *topic {
 }
 
 // Subscribe subscribes to this topic
-func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscriberRateLimit bool) int {
+func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func()) int {
 	t.mu.Lock()
 	defer t.mu.Unlock()
 	subscriberID := rand.Int()
 	t.subscribers[subscriberID] = &topicSubscriber{
-		visitor:             visitor, // May be empty
-		subscriber:          s,
-		cancel:              cancel,
-		subscriberRateLimit: subscriberRateLimit,
+		visitor:    visitor, // May be empty
+		subscriber: s,
+		cancel:     cancel,
 	}
-
-	// if no subscriber is already handling the rate limit
-	if t.rateVisitor == nil && subscriberRateLimit {
-		t.rateVisitor = visitor
-		t.rateVisitorExpires = time.Time{}
-	}
-
 	return subscriberID
 }
 
 func (t *topic) Stale() bool {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-	// if Time is initialized (not the zero value) and the expiry time has passed
-	if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) {
+	if t.rateVisitorExpires.Before(time.Now()) {
 		t.rateVisitor = nil
 	}
 	return len(t.subscribers) == 0 && t.rateVisitor == nil
 }
 
-func (t *topic) Billee() *visitor {
-	t.mu.RLock()
-	defer t.mu.RUnlock()
+func (t *topic) SetRateVisitor(v *visitor) {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	t.rateVisitor = v
+	t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration)
+}
+
+func (t *topic) RateVisitor() *visitor {
+	t.mu.Lock()
+	defer t.mu.Unlock()
+	if t.rateVisitorExpires.Before(time.Now()) {
+		t.rateVisitor = nil
+	}
 	return t.rateVisitor
 }
 
@@ -77,24 +77,7 @@ func (t *topic) Billee() *visitor {
 func (t *topic) Unsubscribe(id int) {
 	t.mu.Lock()
 	defer t.mu.Unlock()
-
-	deletingSub := t.subscribers[id]
 	delete(t.subscribers, id)
-
-	// look for an active subscriber (in random order) that wants to handle the rate limit
-	for _, v := range t.subscribers {
-		if v.subscriberRateLimit {
-			t.rateVisitor = v.visitor
-			t.rateVisitorExpires = time.Time{}
-			return
-		}
-	}
-
-	// if no active subscriber is found, count it towards the leaving subscriber
-	if deletingSub.subscriberRateLimit {
-		t.rateVisitor = deletingSub.visitor
-		t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration)
-	}
 }
 
 // Publish asynchronously publishes to all subscribers
diff --git a/server/util.go b/server/util.go
index 26d0854..8ec258f 100644
--- a/server/util.go
+++ b/server/util.go
@@ -16,7 +16,7 @@ func readBoolParam(r *http.Request, defaultValue bool, names ...string) bool {
 	return value == "1" || value == "yes" || value == "true"
 }
 
-func readCommaSeperatedParam(r *http.Request, names ...string) (params []string) {
+func readCommaSeparatedParam(r *http.Request, names ...string) (params []string) {
 	paramStr := readParam(r, names...)
 	if paramStr != "" {
 		params = make([]string, 0)