Token stuff
This commit is contained in:
parent
d3dfeeccc3
commit
d499d20a9c
8 changed files with 194 additions and 64 deletions
15
auth/auth.go
15
auth/auth.go
|
@ -6,13 +6,17 @@ import (
|
|||
"regexp"
|
||||
)
|
||||
|
||||
// Auther is a generic interface to implement password-based authentication and authorization
|
||||
// Auther is a generic interface to implement password and token based authentication and authorization
|
||||
type Auther interface {
|
||||
// Authenticate checks username and password and returns a user if correct. The method
|
||||
// returns in constant-ish time, regardless of whether the user exists or the password is
|
||||
// correct or incorrect.
|
||||
Authenticate(username, password string) (*User, error)
|
||||
|
||||
AuthenticateToken(token string) (*User, error)
|
||||
|
||||
GenerateToken(user *User) (string, error)
|
||||
|
||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||
// permission. The user param may be nil to signal an anonymous user.
|
||||
Authorize(user *User, topic string, perm Permission) error
|
||||
|
@ -56,10 +60,11 @@ type Manager interface {
|
|||
|
||||
// User is a struct that represents a user
|
||||
type User struct {
|
||||
Name string
|
||||
Hash string // password hash (bcrypt)
|
||||
Role Role
|
||||
Grants []Grant
|
||||
Name string
|
||||
Hash string // password hash (bcrypt)
|
||||
Role Role
|
||||
Grants []Grant
|
||||
Language string
|
||||
}
|
||||
|
||||
// Grant is a struct that represents an access control entry to a topic
|
||||
|
|
|
@ -6,10 +6,12 @@ import (
|
|||
"fmt"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite driver
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"heckel.io/ntfy/util"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
tokenLength = 32
|
||||
bcryptCost = 10
|
||||
intentionalSlowDownHash = "$2a$10$YFCQvqQDwIIwnJM1xkAYOeih0dg17UVGanaTStnrSzC8NCWxcLDwy" // Cost should match bcryptCost
|
||||
)
|
||||
|
@ -67,7 +69,17 @@ const (
|
|||
INSERT INTO user (id, user, pass, role) VALUES (1, '*', '', 'anonymous') ON CONFLICT (id) DO NOTHING;
|
||||
COMMIT;
|
||||
`
|
||||
selectUserQuery = `SELECT pass, role FROM user WHERE user = ?`
|
||||
selectUserByNameQuery = `
|
||||
SELECT user, pass, role, language
|
||||
FROM user
|
||||
WHERE user = ?
|
||||
`
|
||||
selectUserByTokenQuery = `
|
||||
SELECT user, pass, role, language
|
||||
FROM user
|
||||
JOIN user_token on user.id = user_token.user_id
|
||||
WHERE token = ?
|
||||
`
|
||||
selectTopicPermsQuery = `
|
||||
SELECT read, write
|
||||
FROM user_access
|
||||
|
@ -90,6 +102,8 @@ const (
|
|||
deleteAllAccessQuery = `DELETE FROM user_access`
|
||||
deleteUserAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?)`
|
||||
deleteTopicAccessQuery = `DELETE FROM user_access WHERE user_id = (SELECT id FROM user WHERE user = ?) AND topic = ?`
|
||||
|
||||
insertTokenQuery = `INSERT INTO user_token (user_id, token, expires) VALUES ((SELECT id FROM user WHERE user = ?), ?, ?)`
|
||||
)
|
||||
|
||||
// Schema management queries
|
||||
|
@ -126,7 +140,7 @@ func NewSQLiteAuth(filename string, defaultRead, defaultWrite bool) (*SQLiteAuth
|
|||
}, nil
|
||||
}
|
||||
|
||||
// Authenticate checks username and password and returns a user if correct. The method
|
||||
// AuthenticateUser checks username and password and returns a user if correct. The method
|
||||
// returns in constant-ish time, regardless of whether the user exists or the password is
|
||||
// correct or incorrect.
|
||||
func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
|
||||
|
@ -145,6 +159,23 @@ func (a *SQLiteAuth) Authenticate(username, password string) (*User, error) {
|
|||
return user, nil
|
||||
}
|
||||
|
||||
func (a *SQLiteAuth) AuthenticateToken(token string) (*User, error) {
|
||||
user, err := a.userByToken(token)
|
||||
if err != nil {
|
||||
return nil, ErrUnauthenticated
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (a *SQLiteAuth) GenerateToken(user *User) (string, error) {
|
||||
token := util.RandomString(tokenLength)
|
||||
expires := 1 // FIXME
|
||||
if _, err := a.db.Exec(insertTokenQuery, user.Name, token, expires); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Authorize returns nil if the given user has access to the given topic using the desired
|
||||
// permission. The user param may be nil to signal an anonymous user.
|
||||
func (a *SQLiteAuth) Authorize(user *User, topic string, perm Permission) error {
|
||||
|
@ -255,16 +286,29 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
|
|||
if username == Everyone {
|
||||
return a.everyoneUser()
|
||||
}
|
||||
rows, err := a.db.Query(selectUserQuery, username)
|
||||
rows, err := a.db.Query(selectUserByNameQuery, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.readUser(rows)
|
||||
}
|
||||
|
||||
func (a *SQLiteAuth) userByToken(token string) (*User, error) {
|
||||
rows, err := a.db.Query(selectUserByTokenQuery, token)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a.readUser(rows)
|
||||
}
|
||||
|
||||
func (a *SQLiteAuth) readUser(rows *sql.Rows) (*User, error) {
|
||||
defer rows.Close()
|
||||
var hash, role string
|
||||
var username, hash, role string
|
||||
var language sql.NullString
|
||||
if !rows.Next() {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err := rows.Scan(&hash, &role); err != nil {
|
||||
if err := rows.Scan(&username, &hash, &role, &language); err != nil {
|
||||
return nil, err
|
||||
} else if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
|
@ -274,10 +318,11 @@ func (a *SQLiteAuth) User(username string) (*User, error) {
|
|||
return nil, err
|
||||
}
|
||||
return &User{
|
||||
Name: username,
|
||||
Hash: hash,
|
||||
Role: Role(role),
|
||||
Grants: grants,
|
||||
Name: username,
|
||||
Hash: hash,
|
||||
Role: Role(role),
|
||||
Grants: grants,
|
||||
Language: language.String,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
128
server/server.go
128
server/server.go
|
@ -7,6 +7,7 @@ import (
|
|||
"embed"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
@ -320,23 +321,23 @@ func (s *Server) handleInternal(w http.ResponseWriter, r *http.Request, v *visit
|
|||
} else if r.Method == http.MethodOptions {
|
||||
return s.ensureWebEnabled(s.handleOptions)(w, r, v)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && r.URL.Path == "/" {
|
||||
return s.limitRequests(s.transformBodyJSON(s.authWrite(s.handlePublish)))(w, r, v)
|
||||
return s.limitRequests(s.transformBodyJSON(s.authorizeTopicWrite(s.handlePublish)))(w, r, v)
|
||||
} else if r.Method == http.MethodPost && r.URL.Path == matrixPushPath {
|
||||
return s.limitRequests(s.transformMatrixJSON(s.authWrite(s.handlePublishMatrix)))(w, r, v)
|
||||
return s.limitRequests(s.transformMatrixJSON(s.authorizeTopicWrite(s.handlePublishMatrix)))(w, r, v)
|
||||
} else if (r.Method == http.MethodPut || r.Method == http.MethodPost) && topicPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && publishPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authWrite(s.handlePublish))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicWrite(s.handlePublish))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && jsonPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeJSON))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeJSON))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && ssePathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeSSE))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeSSE))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && rawPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeRaw))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeRaw))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && wsPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleSubscribeWS))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleSubscribeWS))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && authPathRegex.MatchString(r.URL.Path) {
|
||||
return s.limitRequests(s.authRead(s.handleTopicAuth))(w, r, v)
|
||||
return s.limitRequests(s.authorizeTopicRead(s.handleTopicAuth))(w, r, v)
|
||||
} else if r.Method == http.MethodGet && (topicPathRegex.MatchString(r.URL.Path) || externalTopicPathRegex.MatchString(r.URL.Path)) {
|
||||
return s.ensureWebEnabled(s.handleTopic)(w, r, v)
|
||||
}
|
||||
|
@ -403,8 +404,6 @@ func (s *Server) handleUserStats(w http.ResponseWriter, r *http.Request, v *visi
|
|||
return nil
|
||||
}
|
||||
|
||||
var sessions = make(map[string]*auth.User) // token-> user
|
||||
|
||||
type tokenAuthResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
@ -414,8 +413,10 @@ func (s *Server) handleUserAuth(w http.ResponseWriter, r *http.Request, v *visit
|
|||
if v.user == nil {
|
||||
return errHTTPUnauthorized
|
||||
}
|
||||
token := util.RandomString(32)
|
||||
sessions[token] = v.user
|
||||
token, err := s.auth.GenerateToken(v.user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
response := &tokenAuthResponse{
|
||||
|
@ -432,35 +433,41 @@ type userSubscriptionResponse struct {
|
|||
Topic string `json:"topic"`
|
||||
}
|
||||
|
||||
type userNotificationSettingsResponse struct {
|
||||
Sound string `json:"sound"`
|
||||
MinPriority string `json:"min_priority"`
|
||||
DeleteAfter int `json:"delete_after"`
|
||||
}
|
||||
|
||||
type userPlanResponse struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type userAccountResponse struct {
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Plan struct {
|
||||
Id int `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"plan,omitempty"`
|
||||
Notification struct {
|
||||
Sound string `json:"sound"`
|
||||
MinPriority string `json:"min_priority"`
|
||||
DeleteAfter int `json:"delete_after"`
|
||||
} `json:"notification,omitempty"`
|
||||
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Language string `json:"language,omitempty"`
|
||||
Plan *userPlanResponse `json:"plan,omitempty"`
|
||||
Notification *userNotificationSettingsResponse `json:"notification,omitempty"`
|
||||
Subscriptions []*userSubscriptionResponse `json:"subscriptions,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Server) handleUserAccount(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
w.Header().Set("Content-Type", "text/json")
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*") // FIXME remove this
|
||||
var response *userAccountResponse
|
||||
response := &userAccountResponse{}
|
||||
if v.user != nil {
|
||||
response = &userAccountResponse{
|
||||
Username: v.user.Name,
|
||||
Role: string(v.user.Role),
|
||||
Language: "en_US",
|
||||
response.Username = v.user.Name
|
||||
response.Role = string(v.user.Role)
|
||||
response.Language = v.user.Language
|
||||
response.Notification = &userNotificationSettingsResponse{
|
||||
Sound: "dadum",
|
||||
}
|
||||
} else {
|
||||
response = &userAccountResponse{
|
||||
Username: "anonymous",
|
||||
Username: auth.Everyone,
|
||||
Role: string(auth.RoleAnonymous),
|
||||
}
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(response); err != nil {
|
||||
|
@ -1453,15 +1460,15 @@ func (s *Server) transformMatrixJSON(next handleFunc) handleFunc {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) authWrite(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, auth.PermissionWrite)
|
||||
func (s *Server) authorizeTopicWrite(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionWrite)
|
||||
}
|
||||
|
||||
func (s *Server) authRead(next handleFunc) handleFunc {
|
||||
return s.withAuth(next, auth.PermissionRead)
|
||||
func (s *Server) authorizeTopicRead(next handleFunc) handleFunc {
|
||||
return s.autorizeTopic(next, auth.PermissionRead)
|
||||
}
|
||||
|
||||
func (s *Server) withAuth(next handleFunc, perm auth.Permission) handleFunc {
|
||||
func (s *Server) autorizeTopic(next handleFunc, perm auth.Permission) handleFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request, v *visitor) error {
|
||||
if s.auth == nil {
|
||||
return next(w, r, v)
|
||||
|
@ -1508,20 +1515,51 @@ func (s *Server) visitor(r *http.Request) (v *visitor, err error) {
|
|||
visitorID := fmt.Sprintf("ip:%s", ip.String())
|
||||
|
||||
var user *auth.User // may stay nil if no auth header!
|
||||
username, password, ok := extractUserPass(r)
|
||||
if ok {
|
||||
if user, err = s.auth.Authenticate(username, password); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
} else {
|
||||
visitorID = fmt.Sprintf("user:%s", user.Name)
|
||||
}
|
||||
if user, err = s.authenticate(r); err != nil {
|
||||
log.Debug("authentication failed: %s", err.Error())
|
||||
err = errHTTPUnauthorized // Always return visitor, even when error occurs!
|
||||
}
|
||||
if user != nil {
|
||||
visitorID = fmt.Sprintf("user:%s", user.Name)
|
||||
}
|
||||
v = s.visitorFromID(visitorID, ip, user)
|
||||
v.user = user // Update user -- FIXME this is ugly, do "newVisitorFromUser" instead
|
||||
return v, err // Always return visitor, even when error occurs!
|
||||
}
|
||||
|
||||
func (s *Server) authenticate(r *http.Request) (user *auth.User, err error) {
|
||||
value := r.Header.Get("Authorization")
|
||||
queryParam := readQueryParam(r, "authorization", "auth")
|
||||
if queryParam != "" {
|
||||
a, err := base64.RawURLEncoding.DecodeString(queryParam)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
value = string(a)
|
||||
}
|
||||
if value == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.HasPrefix(value, "Bearer") {
|
||||
return s.authenticateBearerAuth(value)
|
||||
}
|
||||
return s.authenticateBasicAuth(r, value)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBasicAuth(r *http.Request, value string) (user *auth.User, err error) {
|
||||
r.Header.Set("Authorization", value)
|
||||
username, password, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
return nil, errors.New("invalid basic auth")
|
||||
}
|
||||
return s.auth.Authenticate(username, password)
|
||||
}
|
||||
|
||||
func (s *Server) authenticateBearerAuth(value string) (user *auth.User, err error) {
|
||||
token := strings.TrimSpace(strings.TrimPrefix(value, "Bearer"))
|
||||
return s.auth.AuthenticateToken(token)
|
||||
}
|
||||
|
||||
func (s *Server) visitorFromID(visitorID string, ip netip.Addr, user *auth.User) *visitor {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
|
|
@ -18,7 +18,7 @@ type testAuther struct {
|
|||
Allow bool
|
||||
}
|
||||
|
||||
func (t testAuther) Authenticate(_, _ string) (*auth.User, error) {
|
||||
func (t testAuther) AuthenticateUser(_, _ string) (*auth.User, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
import {
|
||||
fetchLinesIterator,
|
||||
maybeWithBasicAuth,
|
||||
maybeWithBasicAuth, maybeWithBearerAuth,
|
||||
topicShortUrl,
|
||||
topicUrl,
|
||||
topicUrlAuth,
|
||||
topicUrlJsonPoll,
|
||||
topicUrlJsonPollWithSince, userAuthUrl,
|
||||
topicUrlJsonPollWithSince,
|
||||
userAccountUrl,
|
||||
userAuthUrl,
|
||||
userStatsUrl
|
||||
} from "./utils";
|
||||
import userManager from "./UserManager";
|
||||
|
@ -144,6 +146,20 @@ class Api {
|
|||
console.log(`[Api] Stats`, stats);
|
||||
return stats;
|
||||
}
|
||||
|
||||
async userAccount(baseUrl, token) {
|
||||
const url = userAccountUrl(baseUrl);
|
||||
console.log(`[Api] Fetching user account ${url}`);
|
||||
const response = await fetch(url, {
|
||||
headers: maybeWithBearerAuth({}, token)
|
||||
});
|
||||
if (response.status !== 200) {
|
||||
throw new Error(`Unexpected server response ${response.status}`);
|
||||
}
|
||||
const account = await response.json();
|
||||
console.log(`[Api] Account`, account);
|
||||
return account;
|
||||
}
|
||||
}
|
||||
|
||||
const api = new Api();
|
||||
|
|
|
@ -20,6 +20,7 @@ export const topicUrlAuth = (baseUrl, topic) => `${topicUrl(baseUrl, topic)}/aut
|
|||
export const topicShortUrl = (baseUrl, topic) => shortUrl(topicUrl(baseUrl, topic));
|
||||
export const userStatsUrl = (baseUrl) => `${baseUrl}/user/stats`;
|
||||
export const userAuthUrl = (baseUrl) => `${baseUrl}/user/auth`;
|
||||
export const userAccountUrl = (baseUrl) => `${baseUrl}/user/account`;
|
||||
export const shortUrl = (url) => url.replaceAll(/https?:\/\//g, "");
|
||||
export const expandUrl = (url) => [`https://${url}`, `http://${url}`];
|
||||
export const expandSecureUrl = (url) => `https://${url}`;
|
||||
|
@ -95,7 +96,6 @@ export const unmatchedTags = (tags) => {
|
|||
else return tags.filter(tag => !(tag in emojis));
|
||||
}
|
||||
|
||||
|
||||
export const maybeWithBasicAuth = (headers, user) => {
|
||||
if (user) {
|
||||
headers['Authorization'] = `Basic ${encodeBase64(`${user.username}:${user.password}`)}`;
|
||||
|
@ -103,6 +103,13 @@ export const maybeWithBasicAuth = (headers, user) => {
|
|||
return headers;
|
||||
}
|
||||
|
||||
export const maybeWithBearerAuth = (headers, token) => {
|
||||
if (token) {
|
||||
headers['Authorization'] = `Bearer ${token}`;
|
||||
}
|
||||
return headers;
|
||||
}
|
||||
|
||||
export const basicAuth = (username, password) => {
|
||||
return `Basic ${encodeBase64(`${username}:${password}`)}`;
|
||||
}
|
||||
|
|
|
@ -25,6 +25,10 @@ import "./i18n"; // Translations!
|
|||
import {Backdrop, CircularProgress} from "@mui/material";
|
||||
import Home from "./Home";
|
||||
import Login from "./Login";
|
||||
import i18n from "i18next";
|
||||
import api from "../app/Api";
|
||||
import prefs from "../app/Prefs";
|
||||
import session from "../app/Session";
|
||||
|
||||
// TODO races when two tabs are open
|
||||
// TODO investigate service workers
|
||||
|
@ -81,6 +85,21 @@ const Layout = () => {
|
|||
useBackgroundProcesses();
|
||||
useEffect(() => updateTitle(newNotificationsCount), [newNotificationsCount]);
|
||||
|
||||
useEffect(() => {
|
||||
(async () => {
|
||||
const account = await api.userAccount("http://localhost:2586", session.token());
|
||||
if (account) {
|
||||
if (account.language) {
|
||||
await i18n.changeLanguage(account.language);
|
||||
}
|
||||
if (account.notification) {
|
||||
if (account.notification.sound) {
|
||||
await prefs.setSound(account.notification.sound);
|
||||
}
|
||||
}
|
||||
}
|
||||
})();
|
||||
});
|
||||
return (
|
||||
<Box sx={{display: 'flex'}}>
|
||||
<CssBaseline/>
|
||||
|
|
|
@ -32,7 +32,7 @@ const Login = () => {
|
|||
email: data.get('email'),
|
||||
password: data.get('password'),
|
||||
});
|
||||
const user ={
|
||||
const user = {
|
||||
username: data.get('email'),
|
||||
password: data.get('password'),
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue