Finish plugin API and add basic login system

This commit is contained in:
Tulir Asokan 2018-10-31 02:03:27 +02:00
parent d7f072aeff
commit 14fd0d6ac9
16 changed files with 160 additions and 62 deletions

1
.gitignore vendored
View file

@ -13,3 +13,4 @@ __pycache__
logs/ logs/
plugins/ plugins/
trash/

View file

@ -30,8 +30,10 @@ server:
# Set to "generate" to generate and save a new token at startup. # Set to "generate" to generate and save a new token at startup.
unshared_secret: generate unshared_secret: generate
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
# to prevent normal login. Root is a special user that can't have a password and will always exist.
admins: admins:
- "@admin:example.com" root: ""
# Python logging configuration. # Python logging configuration.
# #

View file

@ -14,7 +14,6 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from sqlalchemy import orm from sqlalchemy import orm
from time import time
import sqlalchemy as sql import sqlalchemy as sql
import logging.config import logging.config
import argparse import argparse
@ -22,7 +21,6 @@ import asyncio
import signal import signal
import copy import copy
import sys import sys
import os
from .config import Config from .config import Config
from .db import Base, init as init_db from .db import Base, init as init_db

View file

@ -15,9 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
import random import random
import string import string
import bcrypt
import re
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
class Config(BaseFileConfig): class Config(BaseFileConfig):
@staticmethod @staticmethod
@ -27,16 +31,35 @@ class Config(BaseFileConfig):
def do_update(self, helper: ConfigUpdateHelper) -> None: def do_update(self, helper: ConfigUpdateHelper) -> None:
base, copy, _ = helper base, copy, _ = helper
copy("database") copy("database")
copy("plugin_directories") copy("plugin_directories.upload")
copy("plugin_db_directory") copy("plugin_directories.load")
copy("plugin_directories.trash")
copy("plugin_directories.db")
copy("server.hostname") copy("server.hostname")
copy("server.port") copy("server.port")
copy("server.listen") copy("server.listen")
copy("server.base_path") copy("server.appservice_base_path")
shared_secret = self["server.shared_secret"] shared_secret = self["server.unshared_secret"]
if shared_secret is None or shared_secret == "generate": if shared_secret is None or shared_secret == "generate":
base["server.shared_secret"] = self._new_token() base["server.unshared_secret"] = self._new_token()
else: else:
base["server.shared_secret"] = shared_secret base["server.unshared_secret"] = shared_secret
copy("admins") copy("admins")
for username, password in base["admins"].items():
if password and not bcrypt_regex.match(password):
if password == "password":
password = self._new_token()
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
bcrypt.gensalt()).decode("utf-8")
copy("logging") copy("logging")
def is_admin(self, user: str) -> bool:
return user == "root" or user in self["admins"]
def check_password(self, user: str, passwd: str) -> bool:
if user == "root":
return False
passwd_hash = self["admins"].get(user, None)
if not passwd_hash:
return False
return bcrypt.checkpw(passwd.encode("utf-8"), passwd_hash.encode("utf-8"))

View file

@ -87,13 +87,6 @@ class PluginInstance:
def load_config(self) -> CommentedMap: def load_config(self) -> CommentedMap:
return yaml.load(self.db_instance.config) return yaml.load(self.db_instance.config)
def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]:
try:
base = self.loader.read_file("base-config.yaml")
return RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
return None
def save_config(self, data: RecursiveDict[CommentedMap]) -> None: def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO() buf = io.StringIO()
yaml.dump(data, buf) yaml.dump(data, buf)
@ -103,14 +96,23 @@ class PluginInstance:
if not self.enabled: if not self.enabled:
self.log.warning(f"Plugin disabled, not starting.") self.log.warning(f"Plugin disabled, not starting.")
return return
cls = self.loader.load() cls = await self.loader.load()
config_class = cls.get_config_class() config_class = cls.get_config_class()
if config_class: if config_class:
self.config = config_class(self.load_config, self.load_config_base, try:
self.save_config) base = await self.loader.read_file("base-config.yaml")
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
except (FileNotFoundError, KeyError):
base_file = None
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
self.plugin = cls(self.client.client, self.id, self.log, self.config, self.plugin = cls(self.client.client, self.id, self.log, self.config,
self.mb_config["plugin_db_directory"]) self.mb_config["plugin_directories.db"])
try:
await self.plugin.start() await self.plugin.start()
except Exception:
self.log.exception("Failed to start instance")
self.enabled = False
return
self.running = True self.running = True
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} " self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
f"with user {self.client.id}") f"with user {self.client.id}")

View file

@ -59,10 +59,12 @@ class PluginLoader(ABC):
pass pass
async def stop_instances(self) -> None: async def stop_instances(self) -> None:
await asyncio.gather([instance.stop() for instance in self.references if instance.running]) await asyncio.gather(*[instance.stop() for instance
in self.references if instance.running])
async def start_instances(self) -> None: async def start_instances(self) -> None:
await asyncio.gather([instance.start() for instance in self.references if instance.enabled]) await asyncio.gather(*[instance.start() for instance
in self.references if instance.enabled])
@abstractmethod @abstractmethod
async def load(self) -> Type[PluginClass]: async def load(self) -> Type[PluginClass]:

View file

@ -207,8 +207,10 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}") self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
return plugin return plugin
async def reload(self) -> Type[PluginClass]: async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
await self.unload() await self.unload()
if new_path is not None:
self.path = new_path
return await self.load(reset_cache=True) return await self.load(reset_cache=True)
async def unload(self) -> None: async def unload(self) -> None:

View file

@ -27,8 +27,9 @@ config: Config = None
def is_valid_token(token: str) -> bool: def is_valid_token(token: str) -> bool:
data = verify_token(config["server.unshared_secret"], token) data = verify_token(config["server.unshared_secret"], token)
user_id = data.get("user_id", None) if not data:
return user_id is not None and user_id in config["admins"] return False
return config.is_admin(data.get("user_id", None))
def create_token(user: UserID) -> str: def create_token(user: UserID) -> str:
@ -40,7 +41,9 @@ def create_token(user: UserID) -> str:
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
global config global config
config = cfg config = cfg
from .middleware import auth, error, log from .middleware import auth, error
app = web.Application(loop=loop, middlewares=[auth, log, error]) from .auth import web as _
from .plugin import web as _
app = web.Application(loop=loop, middlewares=[auth, error])
app.add_routes(routes) app.add_routes(routes)
return app return app

View file

@ -0,0 +1,43 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2018 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
import json
from . import routes, config, create_token
from .responses import ErrBadAuth, ErrBodyNotJSON
@routes.post("/login")
async def login(request: web.Request) -> web.Response:
try:
data = await request.json()
except json.JSONDecodeError:
return ErrBodyNotJSON
secret = data.get("secret")
if secret and config["server.unshared_secret"] == secret:
user = data.get("user") or "root"
return web.json_response({
"token": create_token(user),
})
username = data.get("username")
password = data.get("password")
if config.check_password(username, password):
return web.json_response({
"token": create_token(username),
})
return ErrBadAuth

View file

@ -15,25 +15,21 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Awaitable from typing import Callable, Awaitable
from aiohttp import web from aiohttp import web
import logging
from .responses import ErrNoToken, ErrInvalidToken from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed
from . import is_valid_token from . import is_valid_token
Handler = Callable[[web.Request], Awaitable[web.Response]] Handler = Callable[[web.Request], Awaitable[web.Response]]
req_log = logging.getLogger("maubot.mgmt.request")
resp_log = logging.getLogger("maubot.mgmt.response")
@web.middleware @web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response: async def auth(request: web.Request, handler: Handler) -> web.Response:
if request.path.endswith("/login"):
return await handler(request)
token = request.headers.get("Authorization", "") token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "): if not token or not token.startswith("Bearer "):
req_log.debug(f"Request missing auth: {request.remote} {request.method} {request.path}")
return ErrNoToken return ErrNoToken
if not is_valid_token(token[len("Bearer "):]): if not is_valid_token(token[len("Bearer "):]):
req_log.debug(f"Request invalid auth: {request.remote} {request.method} {request.path}")
return ErrInvalidToken return ErrInvalidToken
return await handler(request) return await handler(request)
@ -43,6 +39,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
try: try:
return await handler(request) return await handler(request)
except web.HTTPException as ex: except web.HTTPException as ex:
if ex.status_code == 404:
return ErrPathNotFound
elif ex.status_code == 405:
return ErrMethodNotAllowed
return web.json_response({ return web.json_response({
"error": f"Unhandled HTTP {ex.status}", "error": f"Unhandled HTTP {ex.status}",
"errcode": f"unhandled_http_{ex.status}", "errcode": f"unhandled_http_{ex.status}",
@ -56,12 +56,3 @@ def get_req_no():
global req_no global req_no
req_no += 1 req_no += 1
return req_no return req_no
@web.middleware
async def log(request: web.Request, handler: Handler) -> web.Response:
local_req_no = get_req_no()
req_log.info(f"Request {local_req_no}: {request.remote} {request.method} {request.path}")
resp = await handler(request)
resp_log.info(f"Responded to {local_req_no} from {request.remote}: {resp}")
return resp

View file

@ -92,25 +92,24 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
if plugin.version in filename: if plugin.version in filename:
filename = filename.replace(plugin.version, new_version) filename = filename.replace(plugin.version, new_version)
else: else:
filename = filename.rstrip(".mbp") + new_version + ".mbp" filename = filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp"
path = os.path.join(dirname, filename) path = os.path.join(dirname, filename)
with open(path, "wb") as p: with open(path, "wb") as p:
p.write(content) p.write(content)
old_path = plugin.path old_path = plugin.path
plugin.path = path
await plugin.stop_instances() await plugin.stop_instances()
try: try:
await plugin.reload() await plugin.reload(new_path=path)
except MaubotZipImportError as e: except MaubotZipImportError as e:
plugin.path = old_path
try: try:
await plugin.reload() await plugin.reload(new_path=old_path)
await plugin.start_instances()
except MaubotZipImportError: except MaubotZipImportError:
pass pass
await plugin.start_instances()
return plugin_import_error(str(e), traceback.format_exc()) return plugin_import_error(str(e), traceback.format_exc())
await plugin.start_instances() await plugin.start_instances()
ZippedPluginLoader.trash(plugin.path, reason="update") ZippedPluginLoader.trash(old_path, reason="update")
return RespOK return RespOK

View file

@ -13,27 +13,48 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from http import HTTPStatus
from aiohttp import web from aiohttp import web
ErrBadAuth = web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
}, status=HTTPStatus.UNAUTHORIZED)
ErrNoToken = web.json_response({ ErrNoToken = web.json_response({
"error": "Authorization token missing", "error": "Authorization token missing",
"errcode": "auth_token_missing", "errcode": "auth_token_missing",
}, status=web.HTTPUnauthorized) }, status=HTTPStatus.UNAUTHORIZED)
ErrInvalidToken = web.json_response({ ErrInvalidToken = web.json_response({
"error": "Invalid authorization token", "error": "Invalid authorization token",
"errcode": "auth_token_invalid", "errcode": "auth_token_invalid",
}, status=web.HTTPUnauthorized) }, status=HTTPStatus.UNAUTHORIZED)
ErrPluginNotFound = web.json_response({ ErrPluginNotFound = web.json_response({
"error": "Plugin not found", "error": "Plugin not found",
"errcode": "plugin_not_found", "errcode": "plugin_not_found",
}, status=web.HTTPNotFound) }, status=HTTPStatus.NOT_FOUND)
ErrPathNotFound = web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
}, status=HTTPStatus.NOT_FOUND)
ErrMethodNotAllowed = web.json_response({
"error": "Method not allowed",
"errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
ErrPluginInUse = web.json_response({ ErrPluginInUse = web.json_response({
"error": "Plugin instances of this type still exist", "error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use", "errcode": "plugin_in_use",
}, status=web.HTTPPreconditionFailed) }, status=HTTPStatus.PRECONDITION_FAILED)
ErrBodyNotJSON = web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
def plugin_import_error(error: str, stacktrace: str) -> web.Response: def plugin_import_error(error: str, stacktrace: str) -> web.Response:
@ -41,7 +62,7 @@ def plugin_import_error(error: str, stacktrace: str) -> web.Response:
"error": error, "error": error,
"stacktrace": stacktrace, "stacktrace": stacktrace,
"errcode": "plugin_invalid", "errcode": "plugin_invalid",
}, status=web.HTTPBadRequest) }, status=HTTPStatus.BAD_REQUEST)
def plugin_reload_error(error: str, stacktrace: str) -> web.Response: def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
@ -49,21 +70,21 @@ def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
"error": error, "error": error,
"stacktrace": stacktrace, "stacktrace": stacktrace,
"errcode": "plugin_reload_fail", "errcode": "plugin_reload_fail",
}, status=web.HTTPInternalServerError) }, status=HTTPStatus.INTERNAL_SERVER_ERROR)
ErrUnsupportedPluginLoader = web.json_response({ ErrUnsupportedPluginLoader = web.json_response({
"error": "Existing plugin with same ID uses unsupported plugin loader", "error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader", "errcode": "unsupported_plugin_loader",
}, status=web.HTTPBadRequest) }, status=HTTPStatus.BAD_REQUEST)
ErrNotImplemented = web.json_response({ ErrNotImplemented = web.json_response({
"error": "Not implemented", "error": "Not implemented",
"errcode": "not_implemented", "errcode": "not_implemented",
}, status=web.HTTPNotImplemented) }, status=HTTPStatus.NOT_IMPLEMENTED)
RespOK = web.json_response({ RespOK = web.json_response({
"success": True, "success": True,
}, status=web.HTTPOk) }, status=HTTPStatus.OK)
RespDeleted = web.Response(status=web.HTTPNoContent) RespDeleted = web.Response(status=HTTPStatus.NO_CONTENT)

View file

@ -27,6 +27,9 @@ if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
class Plugin(ABC): class Plugin(ABC):
client: 'MaubotMatrixClient' client: 'MaubotMatrixClient'
id: str id: str
@ -41,7 +44,9 @@ class Plugin(ABC):
self.config = config self.config = config
self.__db_base_path = db_base_path self.__db_base_path = db_base_path
def request_db_engine(self) -> Engine: def request_db_engine(self) -> Optional[Engine]:
if not self.__db_base_path:
raise DatabaseNotConfigured
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db") return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
def set_command_spec(self, spec: 'CommandSpec') -> None: def set_command_spec(self, spec: 'CommandSpec') -> None:

View file

@ -14,6 +14,7 @@
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web from aiohttp import web
import logging
import asyncio import asyncio
from mautrix.api import PathBuilder, Method from mautrix.api import PathBuilder, Method
@ -23,6 +24,8 @@ from .__meta__ import __version__
class MaubotServer: class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
def __init__(self, config: Config, management: web.Application, def __init__(self, config: Config, management: web.Application,
loop: asyncio.AbstractEventLoop) -> None: loop: asyncio.AbstractEventLoop) -> None:
self.loop = loop or asyncio.get_event_loop() self.loop = loop or asyncio.get_event_loop()
@ -45,6 +48,7 @@ class MaubotServer:
await self.runner.setup() await self.runner.setup()
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"]) site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
await site.start() await site.start()
self.log.info(f"Listening on {site.name}")
async def stop(self) -> None: async def stop(self) -> None:
await self.runner.cleanup() await self.runner.cleanup()

View file

@ -5,3 +5,4 @@ alembic
commonmark commonmark
ruamel.yaml ruamel.yaml
attrs attrs
bcrypt

View file

@ -28,6 +28,7 @@ setuptools.setup(
"commonmark>=0.8.1,<1", "commonmark>=0.8.1,<1",
"ruamel.yaml>=0.15.35,<0.16", "ruamel.yaml>=0.15.35,<0.16",
"attrs>=18.1.0,<19", "attrs>=18.1.0,<19",
"bcrypt>=3.1.4,<4",
], ],
classifiers=[ classifiers=[