Add test, fails
This commit is contained in:
parent
4ab450309f
commit
29340e7e24
5 changed files with 89 additions and 45 deletions
|
@ -9,6 +9,12 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/emersion/go-smtp"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"heckel.io/ntfy/log"
|
||||||
|
"heckel.io/ntfy/user"
|
||||||
|
"heckel.io/ntfy/util"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -24,13 +30,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"github.com/emersion/go-smtp"
|
|
||||||
"github.com/gorilla/websocket"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"heckel.io/ntfy/log"
|
|
||||||
"heckel.io/ntfy/user"
|
|
||||||
"heckel.io/ntfy/util"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server is the main server, providing the UI and API for ntfy
|
// Server is the main server, providing the UI and API for ntfy
|
||||||
|
@ -105,15 +104,15 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
firebaseControlTopic = "~control" // See Android if changed
|
firebaseControlTopic = "~control" // See Android if changed
|
||||||
firebasePollTopic = "~poll" // See iOS if changed
|
firebasePollTopic = "~poll" // See iOS if changed
|
||||||
emptyMessageBody = "triggered" // Used if message body is empty
|
emptyMessageBody = "triggered" // Used if message body is empty
|
||||||
newMessageBody = "New message" // Used in poll requests as generic message
|
newMessageBody = "New message" // Used in poll requests as generic message
|
||||||
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
defaultAttachmentMessage = "You received a file: %s" // Used if message body is empty, and there is an attachment
|
||||||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||||
jsonBodyBytesLimit = 16384
|
jsonBodyBytesLimit = 16384
|
||||||
subscriberBilledTopicPrefix = "up_"
|
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||||
subscriberBilledValidity = 12 * time.Hour
|
rateVisitorExpiryDuration = 12 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
// WebSocket constants
|
// WebSocket constants
|
||||||
|
@ -996,7 +995,7 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
defer cancel()
|
defer cancel()
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -1129,7 +1128,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
subscriberIDs := make([]int, 0)
|
subscriberIDs := make([]int, 0)
|
||||||
for _, t := range topics {
|
for _, t := range topics {
|
||||||
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, subscriberBilledTopicPrefix) // temporarily do prefix as well
|
subscriberRateLimited := util.Contains(subscriberRateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) // temporarily do prefix as well
|
||||||
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
subscriberIDs = append(subscriberIDs, t.Subscribe(sub, v, cancel, subscriberRateLimited))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
|
@ -1162,7 +1161,6 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
subscriberTopics = readCommaSeperatedParam(r, "subscriber-rate-limit-topics", "x-subscriber-rate-limit-topics", "srlt")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -40,7 +40,7 @@ func (s *Server) execManager() {
|
||||||
if ev.IsTrace() {
|
if ev.IsTrace() {
|
||||||
expiryMessage := ""
|
expiryMessage := ""
|
||||||
if subs == 0 {
|
if subs == 0 {
|
||||||
expiryTime := time.Until(t.vRateExpires)
|
expiryTime := time.Until(t.rateVisitorExpires)
|
||||||
expiryMessage = ", expires in " + expiryTime.String()
|
expiryMessage = ", expires in " + expiryTime.String()
|
||||||
}
|
}
|
||||||
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
|
ev.Trace("- topic %s: %d subscribers%s", t.ID, subs, expiryMessage)
|
||||||
|
|
|
@ -25,15 +25,15 @@ func (s *Server) limitRequestsWithTopic(next handleFunc) handleFunc {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
vRate := v
|
vrate := v
|
||||||
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
if topicCountsAgainst := t.Billee(); topicCountsAgainst != nil {
|
||||||
vRate = topicCountsAgainst
|
vrate = topicCountsAgainst
|
||||||
}
|
}
|
||||||
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vRate), "topic", t))
|
r = r.WithContext(context.WithValue(context.WithValue(r.Context(), "vRate", vrate), "topic", t))
|
||||||
|
|
||||||
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
if util.ContainsIP(s.config.VisitorRequestExemptIPAddrs, v.ip) {
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
} else if !vRate.RequestAllowed() {
|
} else if !vrate.RequestAllowed() {
|
||||||
return errHTTPTooManyRequestsLimitRequests
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
}
|
}
|
||||||
return next(w, r, v)
|
return next(w, r, v)
|
||||||
|
|
|
@ -1889,6 +1889,49 @@ func TestServer_AnonymousUser_And_NonTierUser_Are_Same_Visitor(t *testing.T) {
|
||||||
require.Equal(t, int64(2), account.Stats.Messages)
|
require.Equal(t, int64(2), account.Stats.Messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_SubscriberRateLimiting(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.VisitorRequestLimitBurst = 3
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
subscriber1Fn := func(r *http.Request) {
|
||||||
|
r.RemoteAddr = "1.2.3.4"
|
||||||
|
}
|
||||||
|
rr := request(t, s, "GET", "/subscriber1topic/json?poll=1", "", map[string]string{
|
||||||
|
"Subscriber-Rate-Limit-Topics": "mytopic1",
|
||||||
|
}, subscriber1Fn)
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
require.Equal(t, "", rr.Body.String())
|
||||||
|
|
||||||
|
subscriber2Fn := func(r *http.Request) {
|
||||||
|
r.RemoteAddr = "8.7.7.1"
|
||||||
|
}
|
||||||
|
rr = request(t, s, "GET", "/upSUB2topic/json?poll=1", "", nil, subscriber2Fn)
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
require.Equal(t, "", rr.Body.String())
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
rr := request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
}
|
||||||
|
rr = request(t, s, "PUT", "/subscriber1topic", "some message", nil)
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
rr := request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||||
|
require.Equal(t, 200, rr.Code) // If we fail here, handlePublish is using the wrong visitor!
|
||||||
|
}
|
||||||
|
rr = request(t, s, "PUT", "/upSUB2topic", "some message", nil)
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
rr := request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
}
|
||||||
|
rr = request(t, s, "PUT", "/some-other-topic", "some message", nil)
|
||||||
|
require.Equal(t, 429, rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestConfig(t *testing.T) *Config {
|
func newTestConfig(t *testing.T) *Config {
|
||||||
conf := NewConfig()
|
conf := NewConfig()
|
||||||
conf.BaseURL = "http://127.0.0.1:12345"
|
conf.BaseURL = "http://127.0.0.1:12345"
|
||||||
|
@ -1914,7 +1957,7 @@ func newTestServer(t *testing.T, config *Config) *Server {
|
||||||
return server
|
return server
|
||||||
}
|
}
|
||||||
|
|
||||||
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string) *httptest.ResponseRecorder {
|
func request(t *testing.T, s *Server, method, url, body string, headers map[string]string, fn ...func(r *http.Request)) *httptest.ResponseRecorder {
|
||||||
rr := httptest.NewRecorder()
|
rr := httptest.NewRecorder()
|
||||||
req, err := http.NewRequest(method, url, strings.NewReader(body))
|
req, err := http.NewRequest(method, url, strings.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1924,6 +1967,9 @@ func request(t *testing.T, s *Server, method, url, body string, headers map[stri
|
||||||
for k, v := range headers {
|
for k, v := range headers {
|
||||||
req.Header.Set(k, v)
|
req.Header.Set(k, v)
|
||||||
}
|
}
|
||||||
|
for _, f := range fn {
|
||||||
|
f(req)
|
||||||
|
}
|
||||||
s.handle(rr, req)
|
s.handle(rr, req)
|
||||||
return rr
|
return rr
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,11 +11,11 @@ import (
|
||||||
// topic represents a channel to which subscribers can subscribe, and publishers
|
// topic represents a channel to which subscribers can subscribe, and publishers
|
||||||
// can publish a message
|
// can publish a message
|
||||||
type topic struct {
|
type topic struct {
|
||||||
ID string
|
ID string
|
||||||
subscribers map[int]*topicSubscriber
|
subscribers map[int]*topicSubscriber
|
||||||
vRate *visitor
|
rateVisitor *visitor
|
||||||
vRateExpires time.Time
|
rateVisitorExpires time.Time
|
||||||
mu sync.Mutex
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type topicSubscriber struct {
|
type topicSubscriber struct {
|
||||||
|
@ -49,9 +49,9 @@ func (t *topic) Subscribe(s subscriber, visitor *visitor, cancel func(), subscri
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no subscriber is already handling the rate limit
|
// if no subscriber is already handling the rate limit
|
||||||
if t.vRate == nil && subscriberRateLimit {
|
if t.rateVisitor == nil && subscriberRateLimit {
|
||||||
t.vRate = visitor
|
t.rateVisitor = visitor
|
||||||
t.vRateExpires = time.Time{}
|
t.rateVisitorExpires = time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
return subscriberID
|
return subscriberID
|
||||||
|
@ -61,16 +61,16 @@ func (t *topic) Stale() bool {
|
||||||
t.mu.Lock()
|
t.mu.Lock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.Unlock()
|
||||||
// if Time is initialized (not the zero value) and the expiry time has passed
|
// if Time is initialized (not the zero value) and the expiry time has passed
|
||||||
if !t.vRateExpires.IsZero() && t.vRateExpires.Before(time.Now()) {
|
if !t.rateVisitorExpires.IsZero() && t.rateVisitorExpires.Before(time.Now()) {
|
||||||
t.vRate = nil
|
t.rateVisitor = nil
|
||||||
}
|
}
|
||||||
return len(t.subscribers) == 0 && t.vRate == nil
|
return len(t.subscribers) == 0 && t.rateVisitor == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *topic) Billee() *visitor {
|
func (t *topic) Billee() *visitor {
|
||||||
t.mu.Lock()
|
t.mu.RLock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.RUnlock()
|
||||||
return t.vRate
|
return t.rateVisitor
|
||||||
}
|
}
|
||||||
|
|
||||||
// Unsubscribe removes the subscription from the list of subscribers
|
// Unsubscribe removes the subscription from the list of subscribers
|
||||||
|
@ -84,16 +84,16 @@ func (t *topic) Unsubscribe(id int) {
|
||||||
// look for an active subscriber (in random order) that wants to handle the rate limit
|
// look for an active subscriber (in random order) that wants to handle the rate limit
|
||||||
for _, v := range t.subscribers {
|
for _, v := range t.subscribers {
|
||||||
if v.subscriberRateLimit {
|
if v.subscriberRateLimit {
|
||||||
t.vRate = v.visitor
|
t.rateVisitor = v.visitor
|
||||||
t.vRateExpires = time.Time{}
|
t.rateVisitorExpires = time.Time{}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if no active subscriber is found, count it towards the leaving subscriber
|
// if no active subscriber is found, count it towards the leaving subscriber
|
||||||
if deletingSub.subscriberRateLimit {
|
if deletingSub.subscriberRateLimit {
|
||||||
t.vRate = deletingSub.visitor
|
t.rateVisitor = deletingSub.visitor
|
||||||
t.vRateExpires = time.Now().Add(subscriberBilledValidity)
|
t.rateVisitorExpires = time.Now().Add(rateVisitorExpiryDuration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,8 +123,8 @@ func (t *topic) Publish(v *visitor, m *message) error {
|
||||||
|
|
||||||
// SubscribersCount returns the number of subscribers to this topic
|
// SubscribersCount returns the number of subscribers to this topic
|
||||||
func (t *topic) SubscribersCount() int {
|
func (t *topic) SubscribersCount() int {
|
||||||
t.mu.Lock()
|
t.mu.RLock()
|
||||||
defer t.mu.Unlock()
|
defer t.mu.RUnlock()
|
||||||
return len(t.subscribers)
|
return len(t.subscribers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue