forked from mirrors/ntfy
Merge branch 'main' into up
This commit is contained in:
commit
41514cd557
3 changed files with 69 additions and 39 deletions
|
@ -43,12 +43,19 @@ type Server struct {
|
||||||
|
|
||||||
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
// errHTTP is a generic HTTP error for any non-200 HTTP error
|
||||||
type errHTTP struct {
|
type errHTTP struct {
|
||||||
Code int
|
Code int `json:"code,omitempty"`
|
||||||
Status string
|
HTTPCode int `json:"http"`
|
||||||
|
Message string `json:"error"`
|
||||||
|
Link string `json:"link,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e errHTTP) Error() string {
|
func (e errHTTP) Error() string {
|
||||||
return fmt.Sprintf("http: %s", e.Status)
|
return e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e errHTTP) JSON() string {
|
||||||
|
b, _ := json.Marshal(&e)
|
||||||
|
return string(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
type indexPage struct {
|
type indexPage struct {
|
||||||
|
@ -105,9 +112,22 @@ var (
|
||||||
docsStaticFs embed.FS
|
docsStaticFs embed.FS
|
||||||
docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
|
docsStaticCached = &util.CachingEmbedFS{ModTime: time.Now(), FS: docsStaticFs}
|
||||||
|
|
||||||
errHTTPBadRequest = &errHTTP{http.StatusBadRequest, http.StatusText(http.StatusBadRequest)}
|
errHTTPNotFound = &errHTTP{40401, http.StatusNotFound, "page not found", ""}
|
||||||
errHTTPNotFound = &errHTTP{http.StatusNotFound, http.StatusText(http.StatusNotFound)}
|
errHTTPTooManyRequestsLimitRequests = &errHTTP{42901, http.StatusTooManyRequests, "limit reached: too many requests, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
errHTTPTooManyRequests = &errHTTP{http.StatusTooManyRequests, http.StatusText(http.StatusTooManyRequests)}
|
errHTTPTooManyRequestsLimitEmails = &errHTTP{42902, http.StatusTooManyRequests, "limit reached: too many emails, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
errHTTPTooManyRequestsLimitSubscriptions = &errHTTP{42903, http.StatusTooManyRequests, "limit reached: too many active subscriptions, please be nice", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
errHTTPTooManyRequestsLimitGlobalTopics = &errHTTP{42904, http.StatusTooManyRequests, "limit reached: the total number of topics on the server has been reached, please contact the admin", "https://ntfy.sh/docs/publish/#limitations"}
|
||||||
|
errHTTPBadRequestEmailDisabled = &errHTTP{40001, http.StatusBadRequest, "e-mail notifications are not enabled", "https://ntfy.sh/docs/config/#e-mail-notifications"}
|
||||||
|
errHTTPBadRequestDelayNoCache = &errHTTP{40002, http.StatusBadRequest, "cannot disable cache for delayed message", ""}
|
||||||
|
errHTTPBadRequestDelayNoEmail = &errHTTP{40003, http.StatusBadRequest, "delayed e-mail notifications are not supported", ""}
|
||||||
|
errHTTPBadRequestDelayCannotParse = &errHTTP{40004, http.StatusBadRequest, "invalid delay parameter: unable to parse delay", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
|
||||||
|
errHTTPBadRequestDelayTooSmall = &errHTTP{40005, http.StatusBadRequest, "invalid delay parameter: too small, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
|
||||||
|
errHTTPBadRequestDelayTooLarge = &errHTTP{40006, http.StatusBadRequest, "invalid delay parameter: too large, please refer to the docs", "https://ntfy.sh/docs/publish/#scheduled-delivery"}
|
||||||
|
errHTTPBadRequestPriorityInvalid = &errHTTP{40007, http.StatusBadRequest, "invalid priority parameter", "https://ntfy.sh/docs/publish/#message-priority"}
|
||||||
|
errHTTPBadRequestSinceInvalid = &errHTTP{40008, http.StatusBadRequest, "invalid since parameter", "https://ntfy.sh/docs/subscribe/api/#fetch-cached-messages"}
|
||||||
|
errHTTPBadRequestTopicInvalid = &errHTTP{40009, http.StatusBadRequest, "invalid topic: path invalid", ""}
|
||||||
|
errHTTPBadRequestTopicDisallowed = &errHTTP{40010, http.StatusBadRequest, "invalid topic: topic name is disallowed", ""}
|
||||||
|
errHTTPInternalError = &errHTTP{50001, http.StatusInternalServerError, "internal server error", ""}
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -241,11 +261,16 @@ func (s *Server) Stop() {
|
||||||
|
|
||||||
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
func (s *Server) handle(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := s.handleInternal(w, r); err != nil {
|
if err := s.handleInternal(w, r); err != nil {
|
||||||
if e, ok := err.(*errHTTP); ok {
|
var e *errHTTP
|
||||||
s.fail(w, r, e.Code, e)
|
var ok bool
|
||||||
} else {
|
if e, ok = err.(*errHTTP); !ok {
|
||||||
s.fail(w, r, http.StatusInternalServerError, err)
|
e = errHTTPInternalError
|
||||||
}
|
}
|
||||||
|
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, e.HTTPCode, err.Error())
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests
|
||||||
|
w.WriteHeader(e.HTTPCode)
|
||||||
|
io.WriteString(w, e.JSON()+"\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -315,7 +340,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
|
m := newDefaultMessage(t.ID, strings.TrimSpace(string(b)))
|
||||||
cache, firebase, email, unifiedpush, err := s.parseParams(r, m)
|
cache, firebase, email, err := s.parseParams(r, m)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -329,13 +354,13 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
|
|
||||||
if email != "" {
|
if email != "" {
|
||||||
if err := v.EmailAllowed(); err != nil {
|
if err := v.EmailAllowed(); err != nil {
|
||||||
return err
|
return errHTTPTooManyRequestsLimitEmails
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.UnifiedPush = unifiedpush
|
m.UnifiedPush = unifiedpush
|
||||||
if s.mailer == nil && email != "" {
|
if s.mailer == nil && email != "" {
|
||||||
return errHTTPBadRequest
|
return errHTTPBadRequestEmailDisabled
|
||||||
}
|
}
|
||||||
if m.Message == "" {
|
if m.Message == "" {
|
||||||
m.Message = emptyMessageBody
|
m.Message = emptyMessageBody
|
||||||
|
@ -376,11 +401,10 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, unifiedpush bool, err error) {
|
func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) {
|
||||||
cache = readParam(r, "x-cache", "cache") != "no"
|
cache = readParam(r, "x-cache", "cache") != "no"
|
||||||
firebase = readParam(r, "x-firebase", "firebase") != "no"
|
firebase = readParam(r, "x-firebase", "firebase") != "no"
|
||||||
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e")
|
||||||
unifiedpush = readParam(r, "up", "unifiedpush") == "1"
|
|
||||||
m.Title = readParam(r, "x-title", "title", "t")
|
m.Title = readParam(r, "x-title", "title", "t")
|
||||||
messageStr := readParam(r, "x-message", "message", "m")
|
messageStr := readParam(r, "x-message", "message", "m")
|
||||||
if messageStr != "" {
|
if messageStr != "" {
|
||||||
|
@ -388,7 +412,7 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase
|
||||||
}
|
}
|
||||||
m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
|
m.Priority, err = util.ParsePriority(readParam(r, "x-priority", "priority", "prio", "p"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, "", false, errHTTPBadRequest
|
return false, false, "", errHTTPBadRequestPriorityInvalid
|
||||||
}
|
}
|
||||||
tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
|
tagsStr := readParam(r, "x-tags", "tags", "tag", "ta")
|
||||||
if tagsStr != "" {
|
if tagsStr != "" {
|
||||||
|
@ -400,22 +424,22 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase
|
||||||
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
delayStr := readParam(r, "x-delay", "delay", "x-at", "at", "x-in", "in")
|
||||||
if delayStr != "" {
|
if delayStr != "" {
|
||||||
if !cache {
|
if !cache {
|
||||||
return false, false, "", false, errHTTPBadRequest
|
return false, false, "", errHTTPBadRequestDelayNoCache
|
||||||
}
|
}
|
||||||
if email != "" {
|
if email != "" {
|
||||||
return false, false, "", false, errHTTPBadRequest // we cannot store the email address (yet)
|
return false, false, "", errHTTPBadRequestDelayNoEmail // we cannot store the email address (yet)
|
||||||
}
|
}
|
||||||
delay, err := util.ParseFutureTime(delayStr, time.Now())
|
delay, err := util.ParseFutureTime(delayStr, time.Now())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, false, "", false, errHTTPBadRequest
|
return false, false, "", errHTTPBadRequestDelayCannotParse
|
||||||
} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
|
} else if delay.Unix() < time.Now().Add(s.config.MinDelay).Unix() {
|
||||||
return false, false, "", false, errHTTPBadRequest
|
return false, false, "", errHTTPBadRequestDelayTooSmall
|
||||||
} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
|
} else if delay.Unix() > time.Now().Add(s.config.MaxDelay).Unix() {
|
||||||
return false, false, "", false, errHTTPBadRequest
|
return false, false, "", errHTTPBadRequestDelayTooLarge
|
||||||
}
|
}
|
||||||
m.Time = delay.Unix()
|
m.Time = delay.Unix()
|
||||||
}
|
}
|
||||||
return cache, firebase, email, unifiedpush, nil
|
return cache, firebase, email, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func readParam(r *http.Request, names ...string) string {
|
func readParam(r *http.Request, names ...string) string {
|
||||||
|
@ -470,8 +494,8 @@ func (s *Server) handleSubscribeRaw(w http.ResponseWriter, r *http.Request, v *v
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
|
func (s *Server) handleSubscribe(w http.ResponseWriter, r *http.Request, v *visitor, format string, contentType string, encoder messageEncoder) error {
|
||||||
if err := v.AddSubscription(); err != nil {
|
if err := v.SubscriptionAllowed(); err != nil {
|
||||||
return errHTTPTooManyRequests
|
return errHTTPTooManyRequestsLimitSubscriptions
|
||||||
}
|
}
|
||||||
defer v.RemoveSubscription()
|
defer v.RemoveSubscription()
|
||||||
topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
|
topicsStr := strings.TrimSuffix(r.URL.Path[1:], "/"+format) // Hack
|
||||||
|
@ -617,7 +641,7 @@ func parseSince(r *http.Request, poll bool) (sinceTime, error) {
|
||||||
} else if d, err := time.ParseDuration(since); err == nil {
|
} else if d, err := time.ParseDuration(since); err == nil {
|
||||||
return sinceTime(time.Now().Add(-1 * d)), nil
|
return sinceTime(time.Now().Add(-1 * d)), nil
|
||||||
}
|
}
|
||||||
return sinceNoMessages, errHTTPBadRequest
|
return sinceNoMessages, errHTTPBadRequestSinceInvalid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error {
|
func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error {
|
||||||
|
@ -629,7 +653,7 @@ func (s *Server) handleOptions(w http.ResponseWriter, _ *http.Request) error {
|
||||||
func (s *Server) topicFromPath(path string) (*topic, error) {
|
func (s *Server) topicFromPath(path string) (*topic, error) {
|
||||||
parts := strings.Split(path, "/")
|
parts := strings.Split(path, "/")
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return nil, errHTTPBadRequest
|
return nil, errHTTPBadRequestTopicInvalid
|
||||||
}
|
}
|
||||||
topics, err := s.topicsFromIDs(parts[1])
|
topics, err := s.topicsFromIDs(parts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -644,11 +668,11 @@ func (s *Server) topicsFromIDs(ids ...string) ([]*topic, error) {
|
||||||
topics := make([]*topic, 0)
|
topics := make([]*topic, 0)
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
if util.InStringList(disallowedTopics, id) {
|
if util.InStringList(disallowedTopics, id) {
|
||||||
return nil, errHTTPBadRequest
|
return nil, errHTTPBadRequestTopicDisallowed
|
||||||
}
|
}
|
||||||
if _, ok := s.topics[id]; !ok {
|
if _, ok := s.topics[id]; !ok {
|
||||||
if len(s.topics) >= s.config.GlobalTopicLimit {
|
if len(s.topics) >= s.config.GlobalTopicLimit {
|
||||||
return nil, errHTTPTooManyRequests
|
return nil, errHTTPTooManyRequestsLimitGlobalTopics
|
||||||
}
|
}
|
||||||
s.topics[id] = newTopic(id)
|
s.topics[id] = newTopic(id)
|
||||||
}
|
}
|
||||||
|
@ -766,7 +790,7 @@ func (s *Server) sendDelayedMessages() error {
|
||||||
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
func (s *Server) withRateLimit(w http.ResponseWriter, r *http.Request, handler func(w http.ResponseWriter, r *http.Request, v *visitor) error) error {
|
||||||
v := s.visitor(r)
|
v := s.visitor(r)
|
||||||
if err := v.RequestAllowed(); err != nil {
|
if err := v.RequestAllowed(); err != nil {
|
||||||
return err
|
return errHTTPTooManyRequestsLimitRequests
|
||||||
}
|
}
|
||||||
return handler(w, r, v)
|
return handler(w, r, v)
|
||||||
}
|
}
|
||||||
|
@ -798,9 +822,3 @@ func (s *Server) inc(counter *int64) {
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
*counter++
|
*counter++
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) fail(w http.ResponseWriter, r *http.Request, code int, err error) {
|
|
||||||
log.Printf("[%s] %s - %d - %s", r.RemoteAddr, r.Method, code, err.Error())
|
|
||||||
w.WriteHeader(code)
|
|
||||||
_, _ = io.WriteString(w, fmt.Sprintf("%s\n", http.StatusText(code)))
|
|
||||||
}
|
|
||||||
|
|
|
@ -252,6 +252,7 @@ func TestServer_PublishAtWithCacheError(t *testing.T) {
|
||||||
"In": "30 min",
|
"In": "30 min",
|
||||||
})
|
})
|
||||||
require.Equal(t, 400, response.Code)
|
require.Equal(t, 400, response.Code)
|
||||||
|
require.Equal(t, errHTTPBadRequestDelayNoCache, toHTTPError(t, response.Body.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServer_PublishAtTooShortDelay(t *testing.T) {
|
func TestServer_PublishAtTooShortDelay(t *testing.T) {
|
||||||
|
@ -644,6 +645,12 @@ func toMessage(t *testing.T, s string) *message {
|
||||||
return &m
|
return &m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toHTTPError(t *testing.T, s string) *errHTTP {
|
||||||
|
var e errHTTP
|
||||||
|
require.Nil(t, json.NewDecoder(strings.NewReader(s)).Decode(&e))
|
||||||
|
return &e
|
||||||
|
}
|
||||||
|
|
||||||
func firebaseServiceAccountFile(t *testing.T) string {
|
func firebaseServiceAccountFile(t *testing.T) string {
|
||||||
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
|
if os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE") != "" {
|
||||||
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
|
return os.Getenv("NTFY_TEST_FIREBASE_SERVICE_ACCOUNT_FILE")
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"golang.org/x/time/rate"
|
"golang.org/x/time/rate"
|
||||||
"heckel.io/ntfy/util"
|
"heckel.io/ntfy/util"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -14,6 +15,10 @@ const (
|
||||||
visitorExpungeAfter = 24 * time.Hour
|
visitorExpungeAfter = 24 * time.Hour
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errVisitorLimitReached = errors.New("limit reached")
|
||||||
|
)
|
||||||
|
|
||||||
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
// visitor represents an API user, and its associated rate.Limiter used for rate limiting
|
||||||
type visitor struct {
|
type visitor struct {
|
||||||
config *Config
|
config *Config
|
||||||
|
@ -42,23 +47,23 @@ func (v *visitor) IP() string {
|
||||||
|
|
||||||
func (v *visitor) RequestAllowed() error {
|
func (v *visitor) RequestAllowed() error {
|
||||||
if !v.requests.Allow() {
|
if !v.requests.Allow() {
|
||||||
return errHTTPTooManyRequests
|
return errVisitorLimitReached
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) EmailAllowed() error {
|
func (v *visitor) EmailAllowed() error {
|
||||||
if !v.emails.Allow() {
|
if !v.emails.Allow() {
|
||||||
return errHTTPTooManyRequests
|
return errVisitorLimitReached
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *visitor) AddSubscription() error {
|
func (v *visitor) SubscriptionAllowed() error {
|
||||||
v.mu.Lock()
|
v.mu.Lock()
|
||||||
defer v.mu.Unlock()
|
defer v.mu.Unlock()
|
||||||
if err := v.subscriptions.Add(1); err != nil {
|
if err := v.subscriptions.Add(1); err != nil {
|
||||||
return errHTTPTooManyRequests
|
return errVisitorLimitReached
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue