Finish plugin API and add basic login system
This commit is contained in:
parent
d7f072aeff
commit
14fd0d6ac9
16 changed files with 160 additions and 62 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -13,3 +13,4 @@ __pycache__
|
||||||
|
|
||||||
logs/
|
logs/
|
||||||
plugins/
|
plugins/
|
||||||
|
trash/
|
||||||
|
|
|
@ -30,8 +30,10 @@ server:
|
||||||
# Set to "generate" to generate and save a new token at startup.
|
# Set to "generate" to generate and save a new token at startup.
|
||||||
unshared_secret: generate
|
unshared_secret: generate
|
||||||
|
|
||||||
|
# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password
|
||||||
|
# to prevent normal login. Root is a special user that can't have a password and will always exist.
|
||||||
admins:
|
admins:
|
||||||
- "@admin:example.com"
|
root: ""
|
||||||
|
|
||||||
# Python logging configuration.
|
# Python logging configuration.
|
||||||
#
|
#
|
||||||
|
|
|
@ -14,7 +14,6 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from sqlalchemy import orm
|
from sqlalchemy import orm
|
||||||
from time import time
|
|
||||||
import sqlalchemy as sql
|
import sqlalchemy as sql
|
||||||
import logging.config
|
import logging.config
|
||||||
import argparse
|
import argparse
|
||||||
|
@ -22,7 +21,6 @@ import asyncio
|
||||||
import signal
|
import signal
|
||||||
import copy
|
import copy
|
||||||
import sys
|
import sys
|
||||||
import os
|
|
||||||
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .db import Base, init as init_db
|
from .db import Base, init as init_db
|
||||||
|
|
|
@ -15,9 +15,13 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
|
import bcrypt
|
||||||
|
import re
|
||||||
|
|
||||||
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
|
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
|
||||||
|
|
||||||
|
bcrypt_regex = re.compile(r"^\$2[ayb]\$.{56}$")
|
||||||
|
|
||||||
|
|
||||||
class Config(BaseFileConfig):
|
class Config(BaseFileConfig):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -27,16 +31,35 @@ class Config(BaseFileConfig):
|
||||||
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
def do_update(self, helper: ConfigUpdateHelper) -> None:
|
||||||
base, copy, _ = helper
|
base, copy, _ = helper
|
||||||
copy("database")
|
copy("database")
|
||||||
copy("plugin_directories")
|
copy("plugin_directories.upload")
|
||||||
copy("plugin_db_directory")
|
copy("plugin_directories.load")
|
||||||
|
copy("plugin_directories.trash")
|
||||||
|
copy("plugin_directories.db")
|
||||||
copy("server.hostname")
|
copy("server.hostname")
|
||||||
copy("server.port")
|
copy("server.port")
|
||||||
copy("server.listen")
|
copy("server.listen")
|
||||||
copy("server.base_path")
|
copy("server.appservice_base_path")
|
||||||
shared_secret = self["server.shared_secret"]
|
shared_secret = self["server.unshared_secret"]
|
||||||
if shared_secret is None or shared_secret == "generate":
|
if shared_secret is None or shared_secret == "generate":
|
||||||
base["server.shared_secret"] = self._new_token()
|
base["server.unshared_secret"] = self._new_token()
|
||||||
else:
|
else:
|
||||||
base["server.shared_secret"] = shared_secret
|
base["server.unshared_secret"] = shared_secret
|
||||||
copy("admins")
|
copy("admins")
|
||||||
|
for username, password in base["admins"].items():
|
||||||
|
if password and not bcrypt_regex.match(password):
|
||||||
|
if password == "password":
|
||||||
|
password = self._new_token()
|
||||||
|
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
|
||||||
|
bcrypt.gensalt()).decode("utf-8")
|
||||||
copy("logging")
|
copy("logging")
|
||||||
|
|
||||||
|
def is_admin(self, user: str) -> bool:
|
||||||
|
return user == "root" or user in self["admins"]
|
||||||
|
|
||||||
|
def check_password(self, user: str, passwd: str) -> bool:
|
||||||
|
if user == "root":
|
||||||
|
return False
|
||||||
|
passwd_hash = self["admins"].get(user, None)
|
||||||
|
if not passwd_hash:
|
||||||
|
return False
|
||||||
|
return bcrypt.checkpw(passwd.encode("utf-8"), passwd_hash.encode("utf-8"))
|
||||||
|
|
|
@ -87,13 +87,6 @@ class PluginInstance:
|
||||||
def load_config(self) -> CommentedMap:
|
def load_config(self) -> CommentedMap:
|
||||||
return yaml.load(self.db_instance.config)
|
return yaml.load(self.db_instance.config)
|
||||||
|
|
||||||
def load_config_base(self) -> Optional[RecursiveDict[CommentedMap]]:
|
|
||||||
try:
|
|
||||||
base = self.loader.read_file("base-config.yaml")
|
|
||||||
return RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
|
||||||
except (FileNotFoundError, KeyError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
|
||||||
buf = io.StringIO()
|
buf = io.StringIO()
|
||||||
yaml.dump(data, buf)
|
yaml.dump(data, buf)
|
||||||
|
@ -103,14 +96,23 @@ class PluginInstance:
|
||||||
if not self.enabled:
|
if not self.enabled:
|
||||||
self.log.warning(f"Plugin disabled, not starting.")
|
self.log.warning(f"Plugin disabled, not starting.")
|
||||||
return
|
return
|
||||||
cls = self.loader.load()
|
cls = await self.loader.load()
|
||||||
config_class = cls.get_config_class()
|
config_class = cls.get_config_class()
|
||||||
if config_class:
|
if config_class:
|
||||||
self.config = config_class(self.load_config, self.load_config_base,
|
try:
|
||||||
self.save_config)
|
base = await self.loader.read_file("base-config.yaml")
|
||||||
|
base_file = RecursiveDict(yaml.load(base.decode("utf-8")), CommentedMap)
|
||||||
|
except (FileNotFoundError, KeyError):
|
||||||
|
base_file = None
|
||||||
|
self.config = config_class(self.load_config, lambda: base_file, self.save_config)
|
||||||
self.plugin = cls(self.client.client, self.id, self.log, self.config,
|
self.plugin = cls(self.client.client, self.id, self.log, self.config,
|
||||||
self.mb_config["plugin_db_directory"])
|
self.mb_config["plugin_directories.db"])
|
||||||
|
try:
|
||||||
await self.plugin.start()
|
await self.plugin.start()
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Failed to start instance")
|
||||||
|
self.enabled = False
|
||||||
|
return
|
||||||
self.running = True
|
self.running = True
|
||||||
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
self.log.info(f"Started instance of {self.loader.id} v{self.loader.version} "
|
||||||
f"with user {self.client.id}")
|
f"with user {self.client.id}")
|
||||||
|
|
|
@ -59,10 +59,12 @@ class PluginLoader(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def stop_instances(self) -> None:
|
async def stop_instances(self) -> None:
|
||||||
await asyncio.gather([instance.stop() for instance in self.references if instance.running])
|
await asyncio.gather(*[instance.stop() for instance
|
||||||
|
in self.references if instance.running])
|
||||||
|
|
||||||
async def start_instances(self) -> None:
|
async def start_instances(self) -> None:
|
||||||
await asyncio.gather([instance.start() for instance in self.references if instance.enabled])
|
await asyncio.gather(*[instance.start() for instance
|
||||||
|
in self.references if instance.enabled])
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def load(self) -> Type[PluginClass]:
|
async def load(self) -> Type[PluginClass]:
|
||||||
|
|
|
@ -207,8 +207,10 @@ class ZippedPluginLoader(PluginLoader):
|
||||||
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
self.log.debug(f"Loaded and imported plugin {self.id} from {self.path}")
|
||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
async def reload(self) -> Type[PluginClass]:
|
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
|
||||||
await self.unload()
|
await self.unload()
|
||||||
|
if new_path is not None:
|
||||||
|
self.path = new_path
|
||||||
return await self.load(reset_cache=True)
|
return await self.load(reset_cache=True)
|
||||||
|
|
||||||
async def unload(self) -> None:
|
async def unload(self) -> None:
|
||||||
|
|
|
@ -27,8 +27,9 @@ config: Config = None
|
||||||
|
|
||||||
def is_valid_token(token: str) -> bool:
|
def is_valid_token(token: str) -> bool:
|
||||||
data = verify_token(config["server.unshared_secret"], token)
|
data = verify_token(config["server.unshared_secret"], token)
|
||||||
user_id = data.get("user_id", None)
|
if not data:
|
||||||
return user_id is not None and user_id in config["admins"]
|
return False
|
||||||
|
return config.is_admin(data.get("user_id", None))
|
||||||
|
|
||||||
|
|
||||||
def create_token(user: UserID) -> str:
|
def create_token(user: UserID) -> str:
|
||||||
|
@ -40,7 +41,9 @@ def create_token(user: UserID) -> str:
|
||||||
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
|
||||||
global config
|
global config
|
||||||
config = cfg
|
config = cfg
|
||||||
from .middleware import auth, error, log
|
from .middleware import auth, error
|
||||||
app = web.Application(loop=loop, middlewares=[auth, log, error])
|
from .auth import web as _
|
||||||
|
from .plugin import web as _
|
||||||
|
app = web.Application(loop=loop, middlewares=[auth, error])
|
||||||
app.add_routes(routes)
|
app.add_routes(routes)
|
||||||
return app
|
return app
|
||||||
|
|
43
maubot/management/api/auth.py
Normal file
43
maubot/management/api/auth.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2018 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from aiohttp import web
|
||||||
|
import json
|
||||||
|
|
||||||
|
from . import routes, config, create_token
|
||||||
|
from .responses import ErrBadAuth, ErrBodyNotJSON
|
||||||
|
|
||||||
|
|
||||||
|
@routes.post("/login")
|
||||||
|
async def login(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return ErrBodyNotJSON
|
||||||
|
secret = data.get("secret")
|
||||||
|
if secret and config["server.unshared_secret"] == secret:
|
||||||
|
user = data.get("user") or "root"
|
||||||
|
return web.json_response({
|
||||||
|
"token": create_token(user),
|
||||||
|
})
|
||||||
|
|
||||||
|
username = data.get("username")
|
||||||
|
password = data.get("password")
|
||||||
|
if config.check_password(username, password):
|
||||||
|
return web.json_response({
|
||||||
|
"token": create_token(username),
|
||||||
|
})
|
||||||
|
|
||||||
|
return ErrBadAuth
|
|
@ -15,25 +15,21 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from typing import Callable, Awaitable
|
from typing import Callable, Awaitable
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
import logging
|
|
||||||
|
|
||||||
from .responses import ErrNoToken, ErrInvalidToken
|
from .responses import ErrNoToken, ErrInvalidToken, ErrPathNotFound, ErrMethodNotAllowed
|
||||||
from . import is_valid_token
|
from . import is_valid_token
|
||||||
|
|
||||||
Handler = Callable[[web.Request], Awaitable[web.Response]]
|
Handler = Callable[[web.Request], Awaitable[web.Response]]
|
||||||
|
|
||||||
req_log = logging.getLogger("maubot.mgmt.request")
|
|
||||||
resp_log = logging.getLogger("maubot.mgmt.response")
|
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
@web.middleware
|
||||||
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
async def auth(request: web.Request, handler: Handler) -> web.Response:
|
||||||
|
if request.path.endswith("/login"):
|
||||||
|
return await handler(request)
|
||||||
token = request.headers.get("Authorization", "")
|
token = request.headers.get("Authorization", "")
|
||||||
if not token or not token.startswith("Bearer "):
|
if not token or not token.startswith("Bearer "):
|
||||||
req_log.debug(f"Request missing auth: {request.remote} {request.method} {request.path}")
|
|
||||||
return ErrNoToken
|
return ErrNoToken
|
||||||
if not is_valid_token(token[len("Bearer "):]):
|
if not is_valid_token(token[len("Bearer "):]):
|
||||||
req_log.debug(f"Request invalid auth: {request.remote} {request.method} {request.path}")
|
|
||||||
return ErrInvalidToken
|
return ErrInvalidToken
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
|
|
||||||
|
@ -43,6 +39,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
|
||||||
try:
|
try:
|
||||||
return await handler(request)
|
return await handler(request)
|
||||||
except web.HTTPException as ex:
|
except web.HTTPException as ex:
|
||||||
|
if ex.status_code == 404:
|
||||||
|
return ErrPathNotFound
|
||||||
|
elif ex.status_code == 405:
|
||||||
|
return ErrMethodNotAllowed
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
"error": f"Unhandled HTTP {ex.status}",
|
"error": f"Unhandled HTTP {ex.status}",
|
||||||
"errcode": f"unhandled_http_{ex.status}",
|
"errcode": f"unhandled_http_{ex.status}",
|
||||||
|
@ -56,12 +56,3 @@ def get_req_no():
|
||||||
global req_no
|
global req_no
|
||||||
req_no += 1
|
req_no += 1
|
||||||
return req_no
|
return req_no
|
||||||
|
|
||||||
|
|
||||||
@web.middleware
|
|
||||||
async def log(request: web.Request, handler: Handler) -> web.Response:
|
|
||||||
local_req_no = get_req_no()
|
|
||||||
req_log.info(f"Request {local_req_no}: {request.remote} {request.method} {request.path}")
|
|
||||||
resp = await handler(request)
|
|
||||||
resp_log.info(f"Responded to {local_req_no} from {request.remote}: {resp}")
|
|
||||||
return resp
|
|
||||||
|
|
|
@ -92,25 +92,24 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
|
||||||
if plugin.version in filename:
|
if plugin.version in filename:
|
||||||
filename = filename.replace(plugin.version, new_version)
|
filename = filename.replace(plugin.version, new_version)
|
||||||
else:
|
else:
|
||||||
filename = filename.rstrip(".mbp") + new_version + ".mbp"
|
filename = filename.rstrip(".mbp")
|
||||||
|
filename = f"{filename}-v{new_version}.mbp"
|
||||||
path = os.path.join(dirname, filename)
|
path = os.path.join(dirname, filename)
|
||||||
with open(path, "wb") as p:
|
with open(path, "wb") as p:
|
||||||
p.write(content)
|
p.write(content)
|
||||||
old_path = plugin.path
|
old_path = plugin.path
|
||||||
plugin.path = path
|
|
||||||
await plugin.stop_instances()
|
await plugin.stop_instances()
|
||||||
try:
|
try:
|
||||||
await plugin.reload()
|
await plugin.reload(new_path=path)
|
||||||
except MaubotZipImportError as e:
|
except MaubotZipImportError as e:
|
||||||
plugin.path = old_path
|
|
||||||
try:
|
try:
|
||||||
await plugin.reload()
|
await plugin.reload(new_path=old_path)
|
||||||
|
await plugin.start_instances()
|
||||||
except MaubotZipImportError:
|
except MaubotZipImportError:
|
||||||
pass
|
pass
|
||||||
await plugin.start_instances()
|
|
||||||
return plugin_import_error(str(e), traceback.format_exc())
|
return plugin_import_error(str(e), traceback.format_exc())
|
||||||
await plugin.start_instances()
|
await plugin.start_instances()
|
||||||
ZippedPluginLoader.trash(plugin.path, reason="update")
|
ZippedPluginLoader.trash(old_path, reason="update")
|
||||||
return RespOK
|
return RespOK
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -13,27 +13,48 @@
|
||||||
#
|
#
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from http import HTTPStatus
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
|
ErrBadAuth = web.json_response({
|
||||||
|
"error": "Invalid username or password",
|
||||||
|
"errcode": "invalid_auth",
|
||||||
|
}, status=HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
||||||
ErrNoToken = web.json_response({
|
ErrNoToken = web.json_response({
|
||||||
"error": "Authorization token missing",
|
"error": "Authorization token missing",
|
||||||
"errcode": "auth_token_missing",
|
"errcode": "auth_token_missing",
|
||||||
}, status=web.HTTPUnauthorized)
|
}, status=HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
||||||
ErrInvalidToken = web.json_response({
|
ErrInvalidToken = web.json_response({
|
||||||
"error": "Invalid authorization token",
|
"error": "Invalid authorization token",
|
||||||
"errcode": "auth_token_invalid",
|
"errcode": "auth_token_invalid",
|
||||||
}, status=web.HTTPUnauthorized)
|
}, status=HTTPStatus.UNAUTHORIZED)
|
||||||
|
|
||||||
ErrPluginNotFound = web.json_response({
|
ErrPluginNotFound = web.json_response({
|
||||||
"error": "Plugin not found",
|
"error": "Plugin not found",
|
||||||
"errcode": "plugin_not_found",
|
"errcode": "plugin_not_found",
|
||||||
}, status=web.HTTPNotFound)
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrPathNotFound = web.json_response({
|
||||||
|
"error": "Resource not found",
|
||||||
|
"errcode": "resource_not_found",
|
||||||
|
}, status=HTTPStatus.NOT_FOUND)
|
||||||
|
|
||||||
|
ErrMethodNotAllowed = web.json_response({
|
||||||
|
"error": "Method not allowed",
|
||||||
|
"errcode": "method_not_allowed",
|
||||||
|
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
|
||||||
|
|
||||||
ErrPluginInUse = web.json_response({
|
ErrPluginInUse = web.json_response({
|
||||||
"error": "Plugin instances of this type still exist",
|
"error": "Plugin instances of this type still exist",
|
||||||
"errcode": "plugin_in_use",
|
"errcode": "plugin_in_use",
|
||||||
}, status=web.HTTPPreconditionFailed)
|
}, status=HTTPStatus.PRECONDITION_FAILED)
|
||||||
|
|
||||||
|
ErrBodyNotJSON = web.json_response({
|
||||||
|
"error": "Request body is not JSON",
|
||||||
|
"errcode": "body_not_json",
|
||||||
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||||
|
@ -41,7 +62,7 @@ def plugin_import_error(error: str, stacktrace: str) -> web.Response:
|
||||||
"error": error,
|
"error": error,
|
||||||
"stacktrace": stacktrace,
|
"stacktrace": stacktrace,
|
||||||
"errcode": "plugin_invalid",
|
"errcode": "plugin_invalid",
|
||||||
}, status=web.HTTPBadRequest)
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
|
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
|
||||||
|
@ -49,21 +70,21 @@ def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
|
||||||
"error": error,
|
"error": error,
|
||||||
"stacktrace": stacktrace,
|
"stacktrace": stacktrace,
|
||||||
"errcode": "plugin_reload_fail",
|
"errcode": "plugin_reload_fail",
|
||||||
}, status=web.HTTPInternalServerError)
|
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
|
||||||
|
|
||||||
|
|
||||||
ErrUnsupportedPluginLoader = web.json_response({
|
ErrUnsupportedPluginLoader = web.json_response({
|
||||||
"error": "Existing plugin with same ID uses unsupported plugin loader",
|
"error": "Existing plugin with same ID uses unsupported plugin loader",
|
||||||
"errcode": "unsupported_plugin_loader",
|
"errcode": "unsupported_plugin_loader",
|
||||||
}, status=web.HTTPBadRequest)
|
}, status=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
ErrNotImplemented = web.json_response({
|
ErrNotImplemented = web.json_response({
|
||||||
"error": "Not implemented",
|
"error": "Not implemented",
|
||||||
"errcode": "not_implemented",
|
"errcode": "not_implemented",
|
||||||
}, status=web.HTTPNotImplemented)
|
}, status=HTTPStatus.NOT_IMPLEMENTED)
|
||||||
|
|
||||||
RespOK = web.json_response({
|
RespOK = web.json_response({
|
||||||
"success": True,
|
"success": True,
|
||||||
}, status=web.HTTPOk)
|
}, status=HTTPStatus.OK)
|
||||||
|
|
||||||
RespDeleted = web.Response(status=web.HTTPNoContent)
|
RespDeleted = web.Response(status=HTTPStatus.NO_CONTENT)
|
||||||
|
|
|
@ -27,6 +27,9 @@ if TYPE_CHECKING:
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
|
|
||||||
|
|
||||||
|
DatabaseNotConfigured = ValueError("A database for this maubot instance has not been configured.")
|
||||||
|
|
||||||
|
|
||||||
class Plugin(ABC):
|
class Plugin(ABC):
|
||||||
client: 'MaubotMatrixClient'
|
client: 'MaubotMatrixClient'
|
||||||
id: str
|
id: str
|
||||||
|
@ -41,7 +44,9 @@ class Plugin(ABC):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.__db_base_path = db_base_path
|
self.__db_base_path = db_base_path
|
||||||
|
|
||||||
def request_db_engine(self) -> Engine:
|
def request_db_engine(self) -> Optional[Engine]:
|
||||||
|
if not self.__db_base_path:
|
||||||
|
raise DatabaseNotConfigured
|
||||||
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
|
return sql.create_engine(f"sqlite:///{os.path.join(self.__db_base_path, self.id)}.db")
|
||||||
|
|
||||||
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
def set_command_spec(self, spec: 'CommandSpec') -> None:
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# You should have received a copy of the GNU Affero General Public License
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
from mautrix.api import PathBuilder, Method
|
from mautrix.api import PathBuilder, Method
|
||||||
|
@ -23,6 +24,8 @@ from .__meta__ import __version__
|
||||||
|
|
||||||
|
|
||||||
class MaubotServer:
|
class MaubotServer:
|
||||||
|
log: logging.Logger = logging.getLogger("maubot.server")
|
||||||
|
|
||||||
def __init__(self, config: Config, management: web.Application,
|
def __init__(self, config: Config, management: web.Application,
|
||||||
loop: asyncio.AbstractEventLoop) -> None:
|
loop: asyncio.AbstractEventLoop) -> None:
|
||||||
self.loop = loop or asyncio.get_event_loop()
|
self.loop = loop or asyncio.get_event_loop()
|
||||||
|
@ -45,6 +48,7 @@ class MaubotServer:
|
||||||
await self.runner.setup()
|
await self.runner.setup()
|
||||||
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
|
site = web.TCPSite(self.runner, self.config["server.hostname"], self.config["server.port"])
|
||||||
await site.start()
|
await site.start()
|
||||||
|
self.log.info(f"Listening on {site.name}")
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
await self.runner.cleanup()
|
await self.runner.cleanup()
|
||||||
|
|
|
@ -5,3 +5,4 @@ alembic
|
||||||
commonmark
|
commonmark
|
||||||
ruamel.yaml
|
ruamel.yaml
|
||||||
attrs
|
attrs
|
||||||
|
bcrypt
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -28,6 +28,7 @@ setuptools.setup(
|
||||||
"commonmark>=0.8.1,<1",
|
"commonmark>=0.8.1,<1",
|
||||||
"ruamel.yaml>=0.15.35,<0.16",
|
"ruamel.yaml>=0.15.35,<0.16",
|
||||||
"attrs>=18.1.0,<19",
|
"attrs>=18.1.0,<19",
|
||||||
|
"bcrypt>=3.1.4,<4",
|
||||||
],
|
],
|
||||||
|
|
||||||
classifiers=[
|
classifiers=[
|
||||||
|
|
Loading…
Reference in a new issue