diff --git a/server/auth_simple.go b/server/auth_simple.go new file mode 100644 index 0000000..1d24bad --- /dev/null +++ b/server/auth_simple.go @@ -0,0 +1,31 @@ +package server + +/* +sqlite> create table user (id int auto increment, user text, password text not null); +sqlite> create table user_topic (user_id int not null, topic text not null, allow_write int, allow_read int); +sqlite> create table topic (topic text primary key, allow_anonymous_write int, allow_anonymous_read int); +*/ + +const ( + permRead = 1 + permWrite = 2 +) + +type auther interface { + Authenticate(user, pass string) bool + Authorize(user, topic string, perm int) bool +} + +type memAuther struct { +} + +func (m memAuther) Authenticate(user, pass string) bool { + return user == "phil" && pass == "phil" +} + +func (m memAuther) Authorize(user, topic string, perm int) bool { + if perm == permRead { + return true + } + return user == "phil" && topic == "mytopic" +} diff --git a/server/errors.go b/server/errors.go index c776271..0ad81de 100644 --- a/server/errors.go +++ b/server/errors.go @@ -40,6 +40,7 @@ var ( errHTTPBadRequestAttachmentsExpiryBeforeDelivery = &errHTTP{40015, http.StatusBadRequest, "invalid request: attachment expiry before delayed delivery date", ""} errHTTPBadRequestWebSocketsUpgradeHeaderMissing = &errHTTP{40016, http.StatusBadRequest, "invalid request: client not using the websocket protocol", ""} errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""} + errHTTPUnauthorized = &errHTTP{40101, http.StatusUnauthorized, "unauthorized", ""} errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"} errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"} diff --git a/server/server.go b/server/server.go index 4a59bc6..935bc78 100644 --- a/server/server.go +++ b/server/server.go @@ -46,6 +46,7 @@ type Server struct { firebase subscriber mailer mailer messages int64 + auther auther cache cache fileCache *fileCache closeChan chan bool @@ -57,6 +58,9 @@ type indexPage struct { CacheDuration time.Duration } +// handleFunc extends the normal http.HandlerFunc to be able to easily return errors +type handleFunc func(http.ResponseWriter, *http.Request, *visitor) error + var ( topicRegex = regexp.MustCompile(`^[-_A-Za-z0-9]{1,64}$`) // No /! topicPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}$`) // Regex must match JS & Android app! @@ -144,6 +148,7 @@ func New(conf *Config) (*Server, error) { firebase: firebaseSubscriber, mailer: mailer, topics: topics, + auther: &memAuther{}, visitors: make(map[string]*visitor), }, nil } @@ -312,6 +317,7 @@ func (s *Server) handle(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { + v := s.visitor(r) if r.Method == http.MethodGet && r.URL.Path == "/" { return s.handleHome(w, r) } else if r.Method == http.MethodGet && r.URL.Path == "/example.html" { @@ -323,23 +329,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { return s.handleDocs(w, r) } else if r.Method == http.MethodGet && fileRegex.MatchString(r.URL.Path) && s.config.AttachmentCacheDir != "" { - return s.withRateLimit(w, r, s.handleFile) + return s.limitRequests(s.handleFile)(w, r, v) } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) } else if r.Method == http.MethodGet && topicPathRegex.MatchString(r.URL.Path) { return s.handleTopic(w, r) } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handlePublish) + return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handlePublish) + return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v) } else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handleSubscribeJSON) + return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v) } else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handleSubscribeSSE) + return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v) } else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handleSubscribeRaw) + return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v) } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) { - return s.withRateLimit(w, r, s.handleSubscribeWS) + return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v) } return errHTTPNotFound } @@ -1094,12 +1100,45 @@ func (s *Server) sendDelayedMessages() error { return nil } -func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error { - v := s.visitor(r) - if err := v.RequestAllowed(); err != nil { - return errHTTPTooManyRequestsLimitRequests +func (s *Server) limitRequests(next handleFunc) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if err := v.RequestAllowed(); err != nil { + return errHTTPTooManyRequestsLimitRequests + } + return next(w, r, v) + } +} + +func (s *Server) authWrite(next handleFunc) handleFunc { + return s.withAuth(next, permWrite) +} + +func (s *Server) authRead(next handleFunc) handleFunc { + return s.withAuth(next, permRead) +} + +func (s *Server) withAuth(next handleFunc, perm int) handleFunc { + return func(w http.ResponseWriter, r *http.Request, v *visitor) error { + if s.auther == nil { + return next(w, r, v) + } + t, err := s.topicFromPath(r.URL.Path) + if err != nil { + return err + } + user, pass, ok := r.BasicAuth() + if ok { + if !s.auther.Authenticate(user, pass) { + return errHTTPUnauthorized + } + } else { + user = "" // Just in case + } + if !s.auther.Authorize(user, t.ID, perm) { + return errHTTPUnauthorized + } + return next(w, r, v) } - return handler(w, r, v) } // visitor creates or retrieves a rate.Limiter for the given visitor.