WIP: Auth in 80 lines of code :-)
This commit is contained in:
parent
aab705f4a4
commit
2181227a6e
3 changed files with 83 additions and 12 deletions
31
server/auth_simple.go
Normal file
31
server/auth_simple.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package server
|
||||||
|
|
||||||
|
/*
|
||||||
|
sqlite> create table user (id int auto increment, user text, password text not null);
|
||||||
|
sqlite> create table user_topic (user_id int not null, topic text not null, allow_write int, allow_read int);
|
||||||
|
sqlite> create table topic (topic text primary key, allow_anonymous_write int, allow_anonymous_read int);
|
||||||
|
*/
|
||||||
|
|
||||||
|
const (
|
||||||
|
permRead = 1
|
||||||
|
permWrite = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
type auther interface {
|
||||||
|
Authenticate(user, pass string) bool
|
||||||
|
Authorize(user, topic string, perm int) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type memAuther struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m memAuther) Authenticate(user, pass string) bool {
|
||||||
|
return user == "phil" && pass == "phil"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m memAuther) Authorize(user, topic string, perm int) bool {
|
||||||
|
if perm == permRead {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return user == "phil" && topic == "mytopic"
|
||||||
|
}
|
|
@ -40,6 +40,7 @@ var (
|
||||||
errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
|
errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
|
||||||
errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""}
|
errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""}
|
||||||
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
|
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
|
||||||
|
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", ""}
|
||||||
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
|
|
@ -46,6 +46,7 @@ type Server struct {
|
||||||
firebase subscriber
|
firebase subscriber
|
||||||
mailer mailer
|
mailer mailer
|
||||||
messages int64
|
messages int64
|
||||||
|
auther auther
|
||||||
cache cache
|
cache cache
|
||||||
fileCache *fileCache
|
fileCache *fileCache
|
||||||
closeChan chan bool
|
closeChan chan bool
|
||||||
|
@ -57,6 +58,9 @@ type indexPage struct {
|
||||||
CacheDuration time.Duration
|
CacheDuration time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
|
||||||
|
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
|
||||||
|
|
||||||
var (
|
var (
|
||||||
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
|
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
|
||||||
topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
|
topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
|
||||||
|
@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) {
|
||||||
firebase: firebaseSubscriber,
|
firebase: firebaseSubscriber,
|
||||||
mailer: mailer,
|
mailer: mailer,
|
||||||
topics: topics,
|
topics: topics,
|
||||||
|
auther: &memAuther{},
|
||||||
visitors: make(map[string]*visitor),
|
visitors: make(map[string]*visitor),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
v := s.visitor(r)
|
||||||
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
if r.Method == http.MethodGet && r.URL.Path == "/" {
|
||||||
return s.handleHome(w, r)
|
return s.handleHome(w, r)
|
||||||
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
|
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
|
||||||
|
@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
|
||||||
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
|
||||||
return s.handleDocs(w, r)
|
return s.handleDocs(w, r)
|
||||||
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
|
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
|
||||||
return s.withRateLimit(w, r, s.handleFile)
|
return s.limitRequests(s.handleFile)(w, r, v)
|
||||||
} else if r.Method == http.MethodOptions {
|
} else if r.Method == http.MethodOptions {
|
||||||
return s.handleOptions(w, r)
|
return s.handleOptions(w, r)
|
||||||
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.handleTopic(w, r)
|
return s.handleTopic(w, r)
|
||||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
|
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handlePublish)
|
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handlePublish)
|
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handleSubscribeJSON)
|
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handleSubscribeSSE)
|
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handleSubscribeRaw)
|
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
||||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||||
return s.withRateLimit(w, r, s.handleSubscribeWS)
|
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
||||||
}
|
}
|
||||||
return errHTTPNotFound
|
return errHTTPNotFound
|
||||||
}
|
}
|
||||||
|
@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
func (s *Server) limitRequests(next handleFunc) handleFunc {
|
||||||
v := s.visitor(r)
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
if err := v.RequestAllowed(); err != nil {
|
if err := v.RequestAllowed(); err != nil {
|
||||||
return errHTTPTooManyRequestsLimitRequests
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
|
}
|
||||||
|
return next(w, r, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) authWrite(next handleFunc) handleFunc {
|
||||||
|
return s.withAuth(next, permWrite)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) authRead(next handleFunc) handleFunc {
|
||||||
|
return s.withAuth(next, permRead)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||||
|
if s.auther == nil {
|
||||||
|
return next(w, r, v)
|
||||||
|
}
|
||||||
|
t, err := s.topicFromPath(r.URL.Path)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
user, pass, ok := r.BasicAuth()
|
||||||
|
if ok {
|
||||||
|
if !s.auther.Authenticate(user, pass) {
|
||||||
|
return errHTTPUnauthorized
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
user = "" // Just in case
|
||||||
|
}
|
||||||
|
if !s.auther.Authorize(user, t.ID, perm) {
|
||||||
|
return errHTTPUnauthorized
|
||||||
|
}
|
||||||
|
return next(w, r, v)
|
||||||
}
|
}
|
||||||
return handler(w, r, v)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
// visitor creates or retrieves a rate.Limiter for the given visitor.
|
||||||
|
|
Loading…
Reference in a new issue