WIP: Auth in 80 lines of code :-)

This commit is contained in:
Philipp Heckel 2022-01-21 22:22:27 -05:00
parent aab705f4a4
commit 2181227a6e
3 changed files with 83 additions and 12 deletions

31
server/auth_simple.go Normal file
View file

@ -0,0 +1,31 @@
package server
/*
sqlite> create table user (id int auto increment, user text, password text not null);
sqlite> create table user_topic (user_id int not null, topic text not null, allow_write int, allow_read int);
sqlite> create table topic (topic text primary key, allow_anonymous_write int, allow_anonymous_read int);
*/
const (
permRead = 1
permWrite = 2
)
type auther interface {
Authenticate(user, pass string) bool
Authorize(user, topic string, perm int) bool
}
type memAuther struct {
}
func (m memAuther) Authenticate(user, pass string) bool {
return user == "phil" && pass == "phil"
}
func (m memAuther) Authorize(user, topic string, perm int) bool {
if perm == permRead {
return true
}
return user == "phil" && topic == "mytopic"
}

View file

@ -40,6 +40,7 @@ var (
errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""} errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""}
errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""} errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""}
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", ""}
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}

View file

@ -46,6 +46,7 @@ type Server struct {
firebase subscriber firebase subscriber
mailer mailer mailer mailer
messages int64 messages int64
auther auther
cache cache cache cache
fileCache *fileCache fileCache *fileCache
closeChan chan bool closeChan chan bool
@ -57,6 +58,9 @@ type indexPage struct {
CacheDuration time.Duration CacheDuration time.Duration
} }
// handleFunc extends the normal http.HandlerFunc to be able to easily return errors
type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error
var ( var (
topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /!
topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app!
@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) {
firebase: firebaseSubscriber, firebase: firebaseSubscriber,
mailer: mailer, mailer: mailer,
topics: topics, topics: topics,
auther: &memAuther{},
visitors: make(map[string]*visitor), visitors: make(map[string]*visitor),
}, nil }, nil
} }
@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
v := s.visitor(r)
if r.Method == http.MethodGet && r.URL.Path == "/" { if r.Method == http.MethodGet && r.URL.Path == "/" {
return s.handleHome(w, r) return s.handleHome(w, r)
} else if r.Method == http.MethodGet && r.URL.Path == "/example.html" { } else if r.Method == http.MethodGet && r.URL.Path == "/example.html" {
@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error {
} else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) {
return s.handleDocs(w, r) return s.handleDocs(w, r)
} else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" { } else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" {
return s.withRateLimit(w, r, s.handleFile) return s.limitRequests(s.handleFile)(w, r, v)
} else if r.Method == http.MethodOptions { } else if r.Method == http.MethodOptions {
return s.handleOptions(w, r) return s.handleOptions(w, r)
} else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) {
return s.handleTopic(w, r) return s.handleTopic(w, r)
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish) return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handlePublish) return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeJSON) return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeSSE) return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeRaw) return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) { } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
return s.withRateLimit(w, r, s.handleSubscribeWS) return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
} }
return errHTTPNotFound return errHTTPNotFound
} }
@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error {
return nil return nil
} }
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { func (s *Server) limitRequests(next handleFunc) handleFunc {
v := s.visitor(r) return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if err := v.RequestAllowed(); err != nil { if err := v.RequestAllowed(); err != nil {
return errHTTPTooManyRequestsLimitRequests return errHTTPTooManyRequestsLimitRequests
}
return next(w, r, v)
}
}
func (s *Server) authWrite(next handleFunc) handleFunc {
return s.withAuth(next, permWrite)
}
func (s *Server) authRead(next handleFunc) handleFunc {
return s.withAuth(next, permRead)
}
func (s *Server) withAuth(next handleFunc, perm int) handleFunc {
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
if s.auther == nil {
return next(w, r, v)
}
t, err := s.topicFromPath(r.URL.Path)
if err != nil {
return err
}
user, pass, ok := r.BasicAuth()
if ok {
if !s.auther.Authenticate(user, pass) {
return errHTTPUnauthorized
}
} else {
user = "" // Just in case
}
if !s.auther.Authorize(user, t.ID, perm) {
return errHTTPUnauthorized
}
return next(w, r, v)
} }
return handler(w, r, v)
} }
// visitor creates or retrieves a rate.Limiter for the given visitor. // visitor creates or retrieves a rate.Limiter for the given visitor.