diff --git a/server/message.go b/server/message.go index 1afa34f..ad870e0 100644 --- a/server/message.go +++ b/server/message.go @@ -1,9 +1,8 @@ package server import ( - "time" - "heckel.io/ntfy/util" + "time" ) // List of possible events @@ -19,15 +18,14 @@ const ( // message represents a message published to a topic type message struct { - ID string `json:"id"` // Random message ID - Time int64 `json:"time"` // Unix time in seconds - Event string `json:"event"` // One of the above - Topic string `json:"topic"` - Priority int `json:"priority,omitempty"` - Tags []string `json:"tags,omitempty"` - Title string `json:"title,omitempty"` - Message string `json:"message,omitempty"` - UnifiedPush bool `json:"unifiedpush,omitempty"` //this could be 'up' + ID string `json:"id"` // Random message ID + Time int64 `json:"time"` // Unix time in seconds + Event string `json:"event"` // One of the above + Topic string `json:"topic"` + Priority int `json:"priority,omitempty"` + Tags []string `json:"tags,omitempty"` + Title string `json:"title,omitempty"` + Message string `json:"message,omitempty"` } // messageEncoder is a function that knows how to encode a message diff --git a/server/server.go b/server/server.go index 7ef5939..78715d2 100644 --- a/server/server.go +++ b/server/server.go @@ -5,7 +5,11 @@ import ( "context" "embed" "encoding/json" + firebase "firebase.google.com/go" + "firebase.google.com/go/messaging" "fmt" + "google.golang.org/api/option" + "heckel.io/ntfy/util" "html/template" "io" "log" @@ -16,11 +20,6 @@ import ( "strings" "sync" "time" - - firebase "firebase.google.com/go" - "firebase.google.com/go/messaging" - "google.golang.org/api/option" - "heckel.io/ntfy/util" ) // TODO add "max messages in a topic" limit @@ -288,7 +287,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { } else if r.Method == http.MethodOptions { return s.handleOptions(w, r) } else if r.Method == http.MethodGet && topicRegex.MatchString(r.URL.Path) { - return s.handleHome(w, r) + return s.handleTopic(w, r) } else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicRegex.MatchString(r.URL.Path) { return s.withRateLimit(w, r, s.handlePublish) } else if r.Method == http.MethodGet && sendRegex.MatchString(r.URL.Path) { @@ -310,6 +309,17 @@ func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) error { }) } +func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error { + unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see PUT/POST too! + if unifiedpush { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`+"\n") + return err + } + return s.handleHome(w, r) +} + func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error { return nil } @@ -340,25 +350,15 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } m := newDefaultMessage(t.ID, strings.TrimSpace(string(b))) - cache, firebase, email, err := s.parseParams(r, m) + cache, firebase, email, err := s.parsePublishParams(r, m) if err != nil { return err } - - if r.Method == http.MethodGet && unifiedpush { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - _, err := io.WriteString(w, `{"unifiedpush":{"version":1}}`) - return err - } - if email != "" { if err := v.EmailAllowed(); err != nil { return errHTTPTooManyRequestsLimitEmails } } - - m.UnifiedPush = unifiedpush if s.mailer == nil && email != "" { return errHTTPBadRequestEmailDisabled } @@ -371,21 +371,20 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return err } } - if s.firebase != nil && firebase && !delayed && !unifiedpush { + if s.firebase != nil && firebase && !delayed { go func() { if err := s.firebase(m); err != nil { log.Printf("Unable to publish to Firebase: %v", err.Error()) } }() } - if s.mailer != nil && email != "" && !delayed && !unifiedpush { + if s.mailer != nil && email != "" && !delayed { go func() { if err := s.mailer.Send(v.ip, email, m); err != nil { log.Printf("Unable to send email: %v", err.Error()) } }() } - if cache { if err := s.cache.AddMessage(m); err != nil { return err @@ -393,7 +392,6 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito } w.Header().Set("Content-Type", "application/json") w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests - if err := json.NewEncoder(w).Encode(m); err != nil { return err } @@ -401,7 +399,7 @@ func (s *Server) handlePublish(w http.ResponseWriter, r *http.Request, v *visito return nil } -func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { +func (s *Server) parsePublishParams(r *http.Request, m *message) (cache bool, firebase bool, email string, err error) { cache = readParam(r, "x-cache", "cache") != "no" firebase = readParam(r, "x-firebase", "firebase") != "no" email = readParam(r, "x-email", "x-e-mail", "email", "e-mail", "mail", "e") @@ -439,6 +437,10 @@ func (s *Server) parseParams(r *http.Request, m *message) (cache bool, firebase } m.Time = delay.Unix() } + unifiedpush := readParam(r, "x-unifiedpush", "unifiedpush", "up") == "1" // see GET too! + if unifiedpush { + firebase = false + } return cache, firebase, email, nil } diff --git a/server/server_test.go b/server/server_test.go index cf3fc58..589b982 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -583,6 +583,13 @@ func TestServer_PublishEmailNoMailer_Fail(t *testing.T) { require.Equal(t, 400, response.Code) } +func TestServer_UnifiedPushDiscovery(t *testing.T) { + s := newTestServer(t, newTestConfig(t)) + response := request(t, s, "GET", "/mytopic?up=1", "", nil) + require.Equal(t, 200, response.Code) + require.Equal(t, `{"unifiedpush":{"version":1}}`+"\n", response.Body.String()) +} + func newTestConfig(t *testing.T) *Config { conf := NewConfig() conf.CacheFile = filepath.Join(t.TempDir(), "cache.db")