diff --git a/server/server.go b/server/server.go index 6f53794..e77d149 100644 --- a/server/server.go +++ b/server/server.go @@ -69,6 +69,7 @@ var ( ssePathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/sse$`) rawPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/raw$`) wsPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/ws$`) + authPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/auth$`) publishPathRegex = regexp.MustCompile(`^/[-_A-Za-z0-9]{1,64}(,[-_A-Za-z0-9]{1,64})*/(publish|send|trigger)$`) staticRegex = regexp.MustCompile(`^/static/.+`) @@ -331,7 +332,7 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { } else if r.Method == http.MethodGet && r.URL.Path == "/example.html" { return s.handleExample(w, r) } else if r.Method == http.MethodHead && r.URL.Path == "/" { - return s.handleEmpty(w, r) + return s.handleEmpty(w, r, v) } else if r.Method == http.MethodGet && staticRegex.MatchString(r.URL.Path) { return s.handleStatic(w, r) } else if r.Method == http.MethodGet && docsRegex.MatchString(r.URL.Path) { @@ -354,6 +355,8 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request) error { return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v) } else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) { return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v) + } else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) { + return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v) } return errHTTPNotFound } @@ -376,10 +379,17 @@ func (s *Server) handleTopic(w http.ResponseWriter, r *http.Request) error { return s.handleHome(w, r) } -func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request) error { +func (s *Server) handleEmpty(_ http.ResponseWriter, _ *http.Request, _ *visitor) error { return nil } +func (s *Server) handleTopicAuth(w http.ResponseWriter, _ *http.Request, _ *visitor) error { + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Access-Control-Allow-Origin", "*") // CORS, allow cross-origin requests + _, err := io.WriteString(w, `{"success":true}`+"\n") + return err +} + func (s *Server) handleExample(w http.ResponseWriter, _ *http.Request) error { _, err := io.WriteString(w, exampleSource) return err diff --git a/server/server_test.go b/server/server_test.go index 061e6e8..990be39 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "github.com/stretchr/testify/require" + "heckel.io/ntfy/auth" "heckel.io/ntfy/util" "math/rand" "net/http" @@ -524,6 +525,104 @@ func TestServer_SubscribeWithQueryFilters(t *testing.T) { require.Equal(t, keepaliveEvent, messages[2].Event) } +func TestServer_Auth_Success_Admin(t *testing.T) { + c := newTestConfig(t) + c.AuthFile = filepath.Join(t.TempDir(), "user.db") + s := newTestServer(t, c) + + manager := s.auth.(auth.Manager) + require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) + + response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ + "Authorization": basicAuth("phil:phil"), + }) + require.Equal(t, 200, response.Code) + require.Equal(t, `{"success":true}`+"\n", response.Body.String()) +} + +func TestServer_Auth_Success_User(t *testing.T) { + c := newTestConfig(t) + c.AuthFile = filepath.Join(t.TempDir(), "user.db") + c.AuthDefaultRead = false + c.AuthDefaultWrite = false + s := newTestServer(t, c) + + manager := s.auth.(auth.Manager) + require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) + require.Nil(t, manager.AllowAccess("ben", "mytopic", true, true)) // Not mytopic! + + response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ + "Authorization": basicAuth("ben:ben"), + }) + require.Equal(t, 200, response.Code) +} + +func TestServer_Auth_Fail_InvalidPass(t *testing.T) { + c := newTestConfig(t) + c.AuthFile = filepath.Join(t.TempDir(), "user.db") + c.AuthDefaultRead = false + c.AuthDefaultWrite = false + s := newTestServer(t, c) + + manager := s.auth.(auth.Manager) + require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) + + response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ + "Authorization": basicAuth("phil:INVALID"), + }) + require.Equal(t, 401, response.Code) +} + +func TestServer_Auth_Fail_Unauthorized(t *testing.T) { + c := newTestConfig(t) + c.AuthFile = filepath.Join(t.TempDir(), "user.db") + c.AuthDefaultRead = false + c.AuthDefaultWrite = false + s := newTestServer(t, c) + + manager := s.auth.(auth.Manager) + require.Nil(t, manager.AddUser("ben", "ben", auth.RoleUser)) + require.Nil(t, manager.AllowAccess("ben", "sometopic", true, true)) // Not mytopic! + + response := request(t, s, "GET", "/mytopic/auth", "", map[string]string{ + "Authorization": basicAuth("ben:ben"), + }) + require.Equal(t, 403, response.Code) +} + +func TestServer_Auth_Fail_CannotPublish(t *testing.T) { + c := newTestConfig(t) + c.AuthFile = filepath.Join(t.TempDir(), "user.db") + c.AuthDefaultRead = true // Open by default + c.AuthDefaultWrite = true // Open by default + s := newTestServer(t, c) + + manager := s.auth.(auth.Manager) + require.Nil(t, manager.AddUser("phil", "phil", auth.RoleAdmin)) + require.Nil(t, manager.AllowAccess(auth.Everyone, "private", false, false)) + require.Nil(t, manager.AllowAccess(auth.Everyone, "announcements", true, false)) + + response := request(t, s, "PUT", "/mytopic", "test", nil) + require.Equal(t, 200, response.Code) + + response = request(t, s, "GET", "/mytopic/json?poll=1", "", nil) + require.Equal(t, 200, response.Code) + + response = request(t, s, "PUT", "/announcements", "test", nil) + require.Equal(t, 403, response.Code) // Cannot write as anonymous + + response = request(t, s, "PUT", "/announcements", "test", map[string]string{ + "Authorization": basicAuth("phil:phil"), + }) + require.Equal(t, 200, response.Code) + + response = request(t, s, "GET", "/announcements/json?poll=1", "", nil) + require.Equal(t, 200, response.Code) // Anonymous read allowed + + response = request(t, s, "GET", "/private/json?poll=1", "", nil) + require.Equal(t, 403, response.Code) // Anonymous read not allowed +} + /* func TestServer_Curl_Publish_Poll(t *testing.T) { s, port := test.StartServer(t) @@ -988,3 +1087,7 @@ func firebaseServiceAccountFile(t *testing.T) string { t.SkipNow() return "" } + +func basicAuth(s string) string { + return fmt.Sprintf("Basic %s", base64.StdEncoding.EncodeToString([]byte(s))) +}