Only set rate visitor if allowed
This commit is contained in:
parent
2329695a47
commit
bfc3983d06
4 changed files with 151 additions and 17 deletions
|
@ -112,7 +112,6 @@ const (
|
||||||
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
encodingBase64 = "base64" // Used mainly for binary UnifiedPush messages
|
||||||
jsonBodyBytesLimit = 16384
|
jsonBodyBytesLimit = 16384
|
||||||
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
unifiedPushTopicPrefix = "up" // Temporarily, we rate limit all "up*" topics based on the subscriber
|
||||||
rateTopicsWildcard = "*" // Allows defining all topics in the request subscriber-rate-limited topics
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// WebSocket constants
|
// WebSocket constants
|
||||||
|
@ -977,7 +976,9 @@ func (s *Server) handleSubscribeHTTP(w http.ResponseWriter, r *http.Request, v *
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
registerRateVisitors(topics, rateTopics, v)
|
if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
w.Header().Set("Content-Type", contentType+"; charset=utf-8") // Android/Volley client needs charset!
|
||||||
if poll {
|
if poll {
|
||||||
|
@ -1113,7 +1114,9 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
|
||||||
}
|
}
|
||||||
return conn.WriteJSON(msg)
|
return conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
registerRateVisitors(topics, rateTopics, v)
|
if err := s.maybeSetRateVisitors(r, v, topics, rateTopics); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
w.Header().Set("Access-Control-Allow-Origin", s.config.AccessControlAllowOrigin) // CORS, allow cross-origin requests
|
||||||
if poll {
|
if poll {
|
||||||
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
return s.sendOldMessages(topics, since, scheduled, v, sub)
|
||||||
|
@ -1156,23 +1159,62 @@ func parseSubscribeParams(r *http.Request) (poll bool, since sinceMarker, schedu
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerRateVisitors sets the rate visitor on a topic, indicating that all messages published to that topic
|
// maybeSetRateVisitors sets the rate visitor on a topic (v.SetRateVisitor), indicating that all messages published
|
||||||
// will be rate limited against the rate visitor instead of the publishing visitor.
|
// to that topic will be rate limited against the rate visitor instead of the publishing visitor.
|
||||||
|
//
|
||||||
|
// Setting the rate visitor is ony allowed if
|
||||||
|
// - auth-file is not set (everything is open by default)
|
||||||
|
// - the topic is reserved, and v.user is the owner
|
||||||
|
// - the topic is not reserved, and v.user has write access
|
||||||
//
|
//
|
||||||
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
|
// Note: This TEMPORARILY also registers all topics starting with "up" (= UnifiedPush). This is to ease the transition
|
||||||
// until the Android app will send the "Rate-Topics" header.
|
// until the Android app will send the "Rate-Topics" header.
|
||||||
func registerRateVisitors(topics []*topic, rateTopics []string, v *visitor) {
|
func (s *Server) maybeSetRateVisitors(r *http.Request, v *visitor, topics []*topic, rateTopics []string) error {
|
||||||
if len(rateTopics) == 1 && rateTopics[0] == rateTopicsWildcard {
|
// Make a list of topics that we'll actually set the RateVisitor on
|
||||||
for _, t := range topics {
|
eligibleRateTopics := make([]*topic, 0)
|
||||||
t.SetRateVisitor(v)
|
for _, t := range topics {
|
||||||
}
|
if strings.HasPrefix(t.ID, unifiedPushTopicPrefix) || util.Contains(rateTopics, t.ID) {
|
||||||
} else {
|
eligibleRateTopics = append(eligibleRateTopics, t)
|
||||||
for _, t := range topics {
|
|
||||||
if util.Contains(rateTopics, t.ID) || strings.HasPrefix(t.ID, unifiedPushTopicPrefix) {
|
|
||||||
t.SetRateVisitor(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if len(eligibleRateTopics) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If access controls are turned off, v has access to everything, and we can set the rate visitor
|
||||||
|
if s.userManager == nil {
|
||||||
|
return s.setRateVisitors(r, v, eligibleRateTopics)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If access controls are enabled, only set rate visitor if
|
||||||
|
// - topic is reserved, and v.user is the owner
|
||||||
|
// - topic is not reserved, and v.user has write access
|
||||||
|
writableRateTopics := make([]*topic, 0)
|
||||||
|
for _, t := range topics {
|
||||||
|
ownerUserID, err := s.userManager.ReservationOwner(t.ID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if ownerUserID == "" {
|
||||||
|
if err := s.userManager.Authorize(v.User(), t.ID, user.PermissionWrite); err == nil {
|
||||||
|
writableRateTopics = append(writableRateTopics, t)
|
||||||
|
}
|
||||||
|
} else if ownerUserID == v.MaybeUserID() {
|
||||||
|
writableRateTopics = append(writableRateTopics, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.setRateVisitors(r, v, writableRateTopics)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) setRateVisitors(r *http.Request, v *visitor, rateTopics []*topic) error {
|
||||||
|
for _, t := range rateTopics {
|
||||||
|
logvr(v, r).
|
||||||
|
Tag(tagSubscribe).
|
||||||
|
Field("message_topic", t.ID).
|
||||||
|
Debug("Setting visitor as rate visitor for topic %s", t.ID)
|
||||||
|
t.SetRateVisitor(v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
// sendOldMessages selects old messages from the messageCache and calls sub for each of them. It uses since as the
|
||||||
|
|
|
@ -2040,7 +2040,7 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
|
||||||
r.RemoteAddr = "1.2.3.4"
|
r.RemoteAddr = "1.2.3.4"
|
||||||
}
|
}
|
||||||
rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
|
rr := request(t, s, "GET", "/mytopic/json?poll=1", "", map[string]string{
|
||||||
"rate-topics": "*",
|
"rate-topics": "mytopic",
|
||||||
}, subscriberFn)
|
}, subscriberFn)
|
||||||
require.Equal(t, 200, rr.Code)
|
require.Equal(t, 200, rr.Code)
|
||||||
require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
|
require.Equal(t, "1.2.3.4", s.topics["mytopic"].rateVisitor.ip.String())
|
||||||
|
@ -2065,6 +2065,72 @@ func TestServer_SubscriberRateLimiting_VisitorExpiration(t *testing.T) {
|
||||||
require.Nil(t, s.visitors["ip:1.2.3.4"])
|
require.Nil(t, s.visitors["ip:1.2.3.4"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServer_SubscriberRateLimiting_ProtectedTopics(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthDefault = user.PermissionDenyAll
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
// Create some ACLs
|
||||||
|
require.Nil(t, s.userManager.AddTier(&user.Tier{
|
||||||
|
Code: "test",
|
||||||
|
MessageLimit: 5,
|
||||||
|
}))
|
||||||
|
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("ben", "test"))
|
||||||
|
require.Nil(t, s.userManager.AllowAccess("ben", "announcements", user.PermissionReadWrite))
|
||||||
|
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||||
|
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "public_topic", user.PermissionReadWrite))
|
||||||
|
|
||||||
|
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleUser))
|
||||||
|
require.Nil(t, s.userManager.ChangeTier("phil", "test"))
|
||||||
|
require.Nil(t, s.userManager.AddReservation("phil", "reserved-for-phil", user.PermissionReadWrite))
|
||||||
|
|
||||||
|
// Set rate visitor as user "phil" on topic
|
||||||
|
// - "reserved-for-phil": Allowed, because I am the owner
|
||||||
|
// - "public_topic": Allowed, because it has read-write permissions for everyone
|
||||||
|
// - "announcements": NOT allowed, because it has read-only permissions for everyone
|
||||||
|
rr := request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("phil", "phil"),
|
||||||
|
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
|
||||||
|
require.Equal(t, "phil", s.topics["public_topic"].rateVisitor.user.Name)
|
||||||
|
require.Nil(t, s.topics["announcements"].rateVisitor)
|
||||||
|
|
||||||
|
// Set rate visitor as user "ben" on topic
|
||||||
|
// - "reserved-for-phil": NOT allowed, because I am not the owner
|
||||||
|
// - "public_topic": Allowed, because it has read-write permissions for everyone
|
||||||
|
// - "announcements": Allowed, because I have read-write permissions
|
||||||
|
rr = request(t, s, "GET", "/reserved-for-phil,public_topic,announcements/json?poll=1", "", map[string]string{
|
||||||
|
"Authorization": util.BasicAuth("ben", "ben"),
|
||||||
|
"Rate-Topics": "reserved-for-phil,public_topic,announcements",
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
require.Equal(t, "phil", s.topics["reserved-for-phil"].rateVisitor.user.Name)
|
||||||
|
require.Equal(t, "ben", s.topics["public_topic"].rateVisitor.user.Name)
|
||||||
|
require.Equal(t, "ben", s.topics["announcements"].rateVisitor.user.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServer_SubscriberRateLimiting_ProtectedTopics_WithDefaultReadWrite(t *testing.T) {
|
||||||
|
c := newTestConfigWithAuthFile(t)
|
||||||
|
c.AuthDefault = user.PermissionReadWrite
|
||||||
|
s := newTestServer(t, c)
|
||||||
|
|
||||||
|
// Create some ACLs
|
||||||
|
require.Nil(t, s.userManager.AllowAccess(user.Everyone, "announcements", user.PermissionRead))
|
||||||
|
|
||||||
|
// Set rate visitor as ip:1.2.3.4 on topic
|
||||||
|
// - "up1234": Allowed, because no ACLs and nobody owns the topic
|
||||||
|
// - "announcements": NOT allowed, because it has read-only permissions for everyone
|
||||||
|
rr := request(t, s, "GET", "/up1234,announcements/json?poll=1", "", nil, func(r *http.Request) {
|
||||||
|
r.RemoteAddr = "1.2.3.4"
|
||||||
|
})
|
||||||
|
require.Equal(t, 200, rr.Code)
|
||||||
|
require.Equal(t, "1.2.3.4", s.topics["up1234"].rateVisitor.ip.String())
|
||||||
|
require.Nil(t, s.topics["announcements"].rateVisitor)
|
||||||
|
}
|
||||||
|
|
||||||
func newTestConfig(t *testing.T) *Config {
|
func newTestConfig(t *testing.T) *Config {
|
||||||
conf := NewConfig()
|
conf := NewConfig()
|
||||||
conf.BaseURL = "http://127.0.0.1:12345"
|
conf.BaseURL = "http://127.0.0.1:12345"
|
||||||
|
|
|
@ -141,6 +141,7 @@ func (v *visitor) Context() log.Context {
|
||||||
func (v *visitor) contextNoLock() log.Context {
|
func (v *visitor) contextNoLock() log.Context {
|
||||||
info := v.infoLightNoLock()
|
info := v.infoLightNoLock()
|
||||||
fields := log.Context{
|
fields := log.Context{
|
||||||
|
"visitor_id": visitorID(v.ip, v.user),
|
||||||
"visitor_ip": v.ip.String(),
|
"visitor_ip": v.ip.String(),
|
||||||
"visitor_messages": info.Stats.Messages,
|
"visitor_messages": info.Stats.Messages,
|
||||||
"visitor_messages_limit": info.Limits.MessageLimit,
|
"visitor_messages_limit": info.Limits.MessageLimit,
|
||||||
|
|
|
@ -201,7 +201,14 @@ const (
|
||||||
selectUserReservationsCountQuery = `
|
selectUserReservationsCountQuery = `
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM user_access
|
FROM user_access
|
||||||
WHERE user_id = owner_user_id AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
WHERE user_id = owner_user_id
|
||||||
|
AND owner_user_id = (SELECT id FROM user WHERE user = ?)
|
||||||
|
`
|
||||||
|
selectUserReservationsOwnerQuery = `
|
||||||
|
SELECT owner_user_id
|
||||||
|
FROM user_access
|
||||||
|
WHERE topic = ?
|
||||||
|
AND user_id = owner_user_id
|
||||||
`
|
`
|
||||||
selectUserHasReservationQuery = `
|
selectUserHasReservationQuery = `
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
|
@ -1025,6 +1032,24 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
|
||||||
return count, nil
|
return count, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReservationOwner returns user ID of the user that owns this topic, or an
|
||||||
|
// empty string if it's not owned by anyone
|
||||||
|
func (a *Manager) ReservationOwner(topic string) (string, error) {
|
||||||
|
rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if !rows.Next() {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
var ownerUserID string
|
||||||
|
if err := rows.Scan(&ownerUserID); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return ownerUserID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// ChangePassword changes a user's password
|
// ChangePassword changes a user's password
|
||||||
func (a *Manager) ChangePassword(username, password string) error {
|
func (a *Manager) ChangePassword(username, password string) error {
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), a.bcryptCost)
|
||||||
|
|
Loading…
Reference in a new issue