WIP: Access API

This commit is contained in:
binwiederhier 2023-05-13 14:39:31 -04:00
parent bd81aef1c9
commit 625b13280f
10 changed files with 250 additions and 46 deletions

View file

@ -82,6 +82,8 @@ var (
apiHealthPath = "/v1/health"
apiStatsPath = "/v1/stats"
apiTiersPath = "/v1/tiers"
apiUserPath = "/v1/user"
apiAccessPath = "/v1/access"
apiAccountPath = "/v1/account"
apiAccountTokenPath = "/v1/account/token"
apiAccountPasswordPath = "/v1/account/password"
@ -411,6 +413,10 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
return s.handleHealth(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == webConfigPath {
return s.ensureWebEnabled(s.handleWebConfig)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccessPath {
return s.ensureAdmin(s.handleAccessAllow)(w, r, v)
} else if r.Method == http.MethodDelete && r.URL.Path == apiAccessPath {
return s.ensureAdmin(s.handleAccessReset)(w, r, v)
} else if r.Method == http.MethodPost && r.URL.Path == apiAccountPath {
return s.ensureUserManager(s.handleAccountCreate)(w, r, v)
} else if r.Method == http.MethodGet && r.URL.Path == apiAccountPath {
@ -1192,7 +1198,7 @@ func (s *Server) handleSubscribeWS(w http.ResponseWriter, r *http.Request, v *vi
}
defer conn.Close()
// Subscription connections can be canceled externally, see topic.CancelSubscribers
// Subscription connections can be canceled externally, see topic.CancelSubscribersExceptUser
cancelCtx, cancel := context.WithCancel(context.Background())
defer cancel()
@ -1434,6 +1440,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request, _ *visito
return nil
}
// topicFromPath returns the topic from a root path (e.g. /mytopic), creating it if it doesn't exist.
func (s *Server) topicFromPath(path string) (*topic, error) {
parts := strings.Split(path, "/")
if len(parts) < 2 {
@ -1442,6 +1449,7 @@ func (s *Server) topicFromPath(path string) (*topic, error) {
return s.topicFromID(parts[1])
}
// topicFromID returns the topic from a root path (e.g. /mytopic,mytopic2), creating it if it doesn't exist.
func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
parts := strings.Split(path, "/")
if len(parts) < 2 {
@ -1455,6 +1463,7 @@ func (s *Server) topicsFromPath(path string) ([]*topic, string, error) {
return topics, parts[1], nil
}
// topicsFromIDs returns the topics with the given IDs, creating them if they don't exist.
func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
s.mu.Lock()
defer s.mu.Unlock()
@ -1474,6 +1483,7 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
return topics, nil
}
// topicFromID returns the topic with the given ID, creating it if it doesn't exist.
func (s *Server) topicFromID(id string) (*topic, error) {
topics, err := s.topicsFromIDs(id)
if err != nil {
@ -1482,6 +1492,23 @@ func (s *Server) topicFromID(id string) (*topic, error) {
return topics[0], nil
}
// topicsFromPattern returns a list of topics matching the given pattern, but it does not create them.
func (s *Server) topicsFromPattern(pattern string) ([]*topic, error) {
s.mu.RLock()
defer s.mu.RUnlock()
patternRegexp, err := regexp.Compile("^" + strings.ReplaceAll(pattern, "*", ".*") + "$")
if err != nil {
return nil, err
}
topics := make([]*topic, 0)
for _, t := range s.topics {
if patternRegexp.MatchString(t.ID) {
topics = append(topics, t)
}
}
return topics, nil
}
func (s *Server) runSMTPServer() error {
s.smtpServerBackend = newMailBackend(s.config, s.handle)
s.smtpServer = smtp.NewServer(s.smtpServerBackend)

50
server/server_access.go Normal file
View file

@ -0,0 +1,50 @@
package server
import (
"heckel.io/ntfy/user"
"net/http"
)
func (s *Server) handleAccessAllow(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccessAllowRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
permission, err := user.ParsePermission(req.Permission)
if err != nil {
return errHTTPBadRequestPermissionInvalid
}
if err := s.userManager.AllowAccess(req.Username, req.Topic, permission); err != nil {
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) handleAccessReset(w http.ResponseWriter, r *http.Request, v *visitor) error {
req, err := readJSONWithLimit[apiAccessResetRequest](r.Body, jsonBodyBytesLimit, false)
if err != nil {
return err
}
u, err := s.userManager.User(req.Username)
if err != nil {
return err
}
if err := s.userManager.ResetAccess(req.Username, req.Topic); err != nil {
return err
}
if err := s.killUserSubscriber(u, req.Topic); err != nil { // This may be a pattern
return err
}
return s.writeJSON(w, newSuccessResponse())
}
func (s *Server) killUserSubscriber(u *user.User, topicPattern string) error {
topics, err := s.topicsFromPattern(topicPattern)
if err != nil {
return err
}
for _, t := range topics {
t.CancelSubscriberUser(u.ID)
}
return nil
}

View file

@ -0,0 +1,100 @@
package server
import (
"github.com/stretchr/testify/require"
"heckel.io/ntfy/user"
"heckel.io/ntfy/util"
"sync/atomic"
"testing"
"time"
)
func TestAccess_AllowReset(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Subscribing not allowed
rr := request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
// Grant access
rr = request(t, s, "POST", "/v1/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Now subscribing is allowed
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
// Reset access
rr = request(t, s, "DELETE", "/v1/access", `{"username": "ben", "topic":"gold"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Subscribing not allowed (again)
rr = request(t, s, "GET", "/gold/json?poll=1", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 403, rr.Code)
}
func TestAccess_AllowReset_NonAdminAttempt(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
// Grant access fails, because non-admin
rr := request(t, s, "POST", "/v1/access", `{"username": "ben", "topic":"gold", "permission":"ro"}`, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 401, rr.Code)
}
func TestAccess_AllowReset_KillConnection(t *testing.T) {
c := newTestConfigWithAuthFile(t)
c.AuthDefault = user.PermissionDenyAll
s := newTestServer(t, c)
defer s.closeDatabases()
// User and admin, grant access to "gol*" topics
require.Nil(t, s.userManager.AddUser("phil", "phil", user.RoleAdmin))
require.Nil(t, s.userManager.AddUser("ben", "ben", user.RoleUser))
require.Nil(t, s.userManager.AllowAccess("ben", "gol*", user.PermissionRead)) // Wildcard!
start, timeTaken := time.Now(), atomic.Int64{}
go func() {
rr := request(t, s, "GET", "/gold/json", "", map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})
require.Equal(t, 200, rr.Code)
timeTaken.Store(time.Since(start).Milliseconds())
}()
time.Sleep(500 * time.Millisecond)
// Reset access
rr := request(t, s, "DELETE", "/v1/access", `{"username": "ben", "topic":"gol*"}`, map[string]string{
"Authorization": util.BasicAuth("phil", "phil"),
})
require.Equal(t, 200, rr.Code)
// Wait for connection to be killed; this will fail if the connection is never killed
waitFor(t, func() bool {
return timeTaken.Load() >= 500
})
}

View file

@ -444,7 +444,7 @@ func (s *Server) handleAccountReservationAdd(w http.ResponseWriter, r *http.Requ
if err != nil {
return err
}
t.CancelSubscribers(u.ID)
t.CancelSubscribersExceptUser(u.ID)
return s.writeJSON(w, newSuccessResponse())
}

View file

@ -76,6 +76,15 @@ func (s *Server) ensureUser(next handleFunc) handleFunc {
})
}
func (s *Server) ensureAdmin(next handleFunc) handleFunc {
return s.ensureUserManager(func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if !v.User().IsAdmin() {
return errHTTPUnauthorized
}
return next(w, r, v)
})
}
func (s *Server) ensurePaymentsEnabled(next handleFunc) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.config.StripeSecretKey == "" || s.stripe == nil {

View file

@ -141,24 +141,40 @@ func (t *topic) Keepalive() {
t.lastAccess = time.Now()
}
// CancelSubscribers calls the cancel function for all subscribers, forcing
func (t *topic) CancelSubscribers(exceptUserID string) {
// CancelSubscribersExceptUser calls the cancel function for all subscribers, forcing
func (t *topic) CancelSubscribersExceptUser(exceptUserID string) {
t.mu.Lock()
defer t.mu.Unlock()
for _, s := range t.subscribers {
if s.userID != exceptUserID {
log.
Tag(tagSubscribe).
With(t).
Fields(log.Context{
"user_id": s.userID,
}).
Debug("Canceling subscriber %s", s.userID)
s.cancel()
t.cancelUserSubscriber(s)
}
}
}
// CancelSubscriberUser kills the subscriber with the given user ID
func (t *topic) CancelSubscriberUser(userID string) {
t.mu.RLock()
defer t.mu.RUnlock()
for _, s := range t.subscribers {
if s.userID == userID {
t.cancelUserSubscriber(s)
return
}
}
}
func (t *topic) cancelUserSubscriber(s *topicSubscriber) {
log.
Tag(tagSubscribe).
With(t).
Fields(log.Context{
"user_id": s.userID,
}).
Debug("Canceling subscriber with user ID %s", s.userID)
s.cancel()
}
func (t *topic) Context() log.Context {
t.mu.RLock()
defer t.mu.RUnlock()

View file

@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/require"
)
func TestTopic_CancelSubscribers(t *testing.T) {
func TestTopic_CancelSubscribersExceptUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
@ -27,11 +27,34 @@ func TestTopic_CancelSubscribers(t *testing.T) {
to.Subscribe(subFn, "", cancelFn1)
to.Subscribe(subFn, "u_phil", cancelFn2)
to.CancelSubscribers("u_phil")
to.CancelSubscribersExceptUser("u_phil")
require.True(t, canceled1.Load())
require.False(t, canceled2.Load())
}
func TestTopic_CancelSubscribersUser(t *testing.T) {
t.Parallel()
subFn := func(v *visitor, msg *message) error {
return nil
}
canceled1 := atomic.Bool{}
cancelFn1 := func() {
canceled1.Store(true)
}
canceled2 := atomic.Bool{}
cancelFn2 := func() {
canceled2.Store(true)
}
to := newTopic("mytopic")
to.Subscribe(subFn, "u_another", cancelFn1)
to.Subscribe(subFn, "u_phil", cancelFn2)
to.CancelSubscriberUser("u_phil")
require.False(t, canceled1.Load())
require.True(t, canceled2.Load())
}
func TestTopic_Keepalive(t *testing.T) {
t.Parallel()

View file

@ -244,6 +244,17 @@ type apiStatsResponse struct {
MessagesRate float64 `json:"messages_rate"` // Average number of messages per second
}
type apiAccessAllowRequest struct {
Username string `json:"username"`
Topic string `json:"topic"`
Permission string `json:"permission"`
}
type apiAccessResetRequest struct {
Username string `json:"username"`
Topic string `json:"topic"`
}
type apiAccountCreateRequest struct {
Username string `json:"username"`
Password string `json:"password"`