Only set rate visitor if allowed

This commit is contained in:
binwiederhier 2023-02-24 14:45:30 -05:00
parent 2329695a47
commit bfc3983d06
4 changed files with 151 additions and 17 deletions

View file

@ -112,7 +112,6 @@ const (
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
jsonBodyBytesLimit = 16384 jsonBodyBytesLimit = 16384
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics
) )
// WebSocket constants // WebSocket constants
@ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
} }
return nil return nil
} }
registerRateVisitors(topics, rateTopics, v) if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
return err
}
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset! w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
if poll { if poll {
@ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
} }
return conn.WriteJSON(msg) return conn.WriteJSON(msg)
} }
registerRateVisitors(topics, rateTopics, v) if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
return err
}
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
if poll { if poll {
return s.sendOldMessages(topics, since, scheduled, v, sub) return s.sendOldMessages(topics, since, scheduled, v, sub)
@ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
return return
} }
// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic // maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published
// will be rate limited against the rate visitor instead of the publishing visitor. // to that topic will be rate limited against the rate visitor instead of the publishing visitor.
//
// Setting the rate visitor is ony allowed if
// - auth-file is not set (everything is open by default)
// - the topic is reserved, and v.user is the owner
// - the topic is not reserved, and v.user has write access
// //
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition // Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
// until the Android app will send the "Rate-Topics" header. // until the Android app will send the "Rate-Topics" header.
func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) { func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error {
if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard { // Make a list of topics that we'll actually set the RateVisitor on
for _, t := range topics { eligibleRateTopics := make([]*topic, 0)
t.SetRateVisitor(v) for _, t := range topics {
} if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) {
} else { eligibleRateTopics = append(eligibleRateTopics, t)
for _, t := range topics {
if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
t.SetRateVisitor(v)
}
} }
} }
if len(eligibleRateTopics) == 0 {
return nil
}
// If access controls are turned off, v has access to everything, and we can set the rate visitor
if s.userManager == nil {
return s.setRateVisitors(r, v, eligibleRateTopics)
}
// If access controls are enabled, only set rate visitor if
// - topic is reserved, and v.user is the owner
// - topic is not reserved, and v.user has write access
writableRateTopics := make([]*topic, 0)
for _, t := range topics {
ownerUserID, err := s.userManager.ReservationOwner(t.ID)
if err != nil {
return err
}
if ownerUserID == "" {
if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil {
writableRateTopics = append(writableRateTopics, t)
}
} else if ownerUserID == v.MaybeUserID() {
writableRateTopics = append(writableRateTopics, t)
}
}
return s.setRateVisitors(r, v, writableRateTopics)
}
func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error {
for _, t := range rateTopics {
logvr(v, r).
Tag(tagSubscribe).
Field("message_topic", t.ID).
Debug("Setting visitor as rate visitor for topic %s", t.ID)
t.SetRateVisitor(v)
}
return nil
} }
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the // sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the

View file

@ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
r.RemoteAddr = "1.2.3.4" r.RemoteAddr = "1.2.3.4"
} }
rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{ rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
"rate-topics": "*", "rate-topics": "mytopic",
}, subscriberFn) }, subscriberFn)
require.Equal(t, 200, rr.Code) require.Equal(t, 200, rr.Code)
require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String()) require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
@ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
require.Nil(t, s.visitors["ip:1.2.3.4"]) require.Nil(t, s.visitors["ip:1.2.3.4"])
} }
func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
// Create some ACLs
require.Nil(t, s.userManager.AddTier(&user.Tier{
Code: "test",
MessageLimit: 5,
}))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("ben", "test"))
require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite))
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite))
// Set rate visitor as user "phil" on topic
// - "reserved-for-phil": Allowed, because I am the owner
// - "public_topic": Allowed, because it has read-write permissions for everyone
// - "announcements": NOT allowed, because it has read-only permissions for everyone
rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
})
require.Equal(t, 200, rr.Code)
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name)
require.Nil(t, s.topics["announcements"].rateVisitor)
// Set rate visitor as user "ben" on topic
// - "reserved-for-phil": NOT allowed, because I am not the owner
// - "public_topic": Allowed, because it has read-write permissions for everyone
// - "announcements": Allowed, because I have read-write permissions
rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
})
require.Equal(t, 200, rr.Code)
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name)
require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name)
}
func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionReadWrite
s := newTestServer(t, c)
// Create some ACLs
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
// Set rate visitor as ip:1.2.3.4 on topic
// - "up1234": Allowed, because no ACLs and nobody owns the topic
// - "announcements": NOT allowed, because it has read-only permissions for everyone
rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) {
r.RemoteAddr = "1.2.3.4"
})
require.Equal(t, 200, rr.Code)
require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String())
require.Nil(t, s.topics["announcements"].rateVisitor)
}
func newTestConfig(t *testing.T) *Config { func newTestConfig(t *testing.T) *Config {
conf := NewConfig() conf := NewConfig()
conf.BaseURL = "http://127.0.0.1:12345" conf.BaseURL = "http://127.0.0.1:12345"

View file

@ -141,6 +141,7 @@ func (v *visitor) Context() log.Context {
func (v *visitor) contextNoLock() log.Context { func (v *visitor) contextNoLock() log.Context {
info := v.infoLightNoLock() info := v.infoLightNoLock()
fields := log.Context{ fields := log.Context{
"visitor_id": visitorID(v.ip, v.user),
"visitor_ip": v.ip.String(), "visitor_ip": v.ip.String(),
"visitor_messages": info.Stats.Messages, "visitor_messages": info.Stats.Messages,
"visitor_messages_limit": info.Limits.MessageLimit, "visitor_messages_limit": info.Limits.MessageLimit,

View file

@ -201,7 +201,14 @@ const (
selectUserReservationsCountQuery = ` selectUserReservationsCountQuery = `
SELECT COUNT(*) SELECT COUNT(*)
FROM user_access FROM user_access
WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?) WHERE user_id = owner_user_id
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
`
selectUserReservationsOwnerQuery = `
SELECT owner_user_id
FROM user_access
WHERE topic = ?
AND user_id = owner_user_id
` `
selectUserHasReservationQuery = ` selectUserHasReservationQuery = `
SELECT COUNT(*) SELECT COUNT(*)
@ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
return count, nil return count, nil
} }
// ReservationOwner returns user ID of the user that owns this topic, or an
// empty string if it's not owned by anyone
func (a *Manager) ReservationOwner(topic string) (string, error) {
rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
if err != nil {
return "", err
}
defer rows.Close()
if !rows.Next() {
return "", nil
}
var ownerUserID string
if err := rows.Scan(&ownerUserID); err != nil {
return "", err
}
return ownerUserID, nil
}
// ChangePassword changes a user's password // ChangePassword changes a user's password
func (a *Manager) ChangePassword(username, password string) error { func (a *Manager) ChangePassword(username, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost) hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)