Blacken and isort code

This commit is contained in:
Tulir Asokan 2022-03-25 14:22:37 +02:00
parent 6257979e7c
commit 068e268c63
97 changed files with 1781 additions and 1086 deletions

26
.github/workflows/python-lint.yml vendored Normal file
View file

@ -0,0 +1,26 @@
name: Python lint
on: [push, pull_request]
jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: "3.10"
- uses: isort/isort-action@master
with:
sortPaths: "./maubot"
- uses: psf/black@stable
with:
src: "./maubot"
version: "22.1.0"
- name: pre-commit
run: |
pip install pre-commit
pre-commit run -av trailing-whitespace
pre-commit run -av end-of-file-fixer
pre-commit run -av check-yaml
pre-commit run -av check-added-large-files

View file

@ -1,4 +1,11 @@
# maubot
![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg)
[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE)
[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases)
[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry)
[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
A plugin-based [Matrix](https://matrix.org) bot system written in Python.
## Documentation

3
dev-requirements.txt Normal file
View file

@ -0,0 +1,3 @@
pre-commit>=2.10.1,<3
isort>=5.10.1,<6
black==22.1.0

View file

@ -1,4 +1,4 @@
from .__meta__ import __version__
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .plugin_base import Plugin
from .plugin_server import PluginWebApp
from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent
from .__meta__ import __version__

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,15 +17,15 @@ import asyncio
from mautrix.util.program import Program
from .__meta__ import __version__
from .client import Client, init as init_client_class
from .config import Config
from .db import init as init_db
from .server import MaubotServer
from .client import Client, init as init_client_class
from .loader.zip import init as init_zip_loader
from .instance import init as init_plugin_instance_class
from .management.api import init as init_mgmt_api
from .lib.future_awaitable import FutureAwaitable
from .__meta__ import __version__
from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .server import MaubotServer
class Maubot(Program):
@ -41,6 +41,7 @@ class Maubot(Program):
def prepare_log_websocket(self) -> None:
from .management.api.log import init, stop_all
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))

View file

@ -1,2 +1,3 @@
from . import app
app()

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by

View file

@ -1,2 +1,2 @@
from .cliq import command, option
from .validators import SPDXValidator, VersionValidator, PathValidator
from .validators import PathValidator, SPDXValidator, VersionValidator

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Any, Callable, Union, Optional, Type
import functools
import traceback
import inspect
from __future__ import annotations
from typing import Any, Callable
import asyncio
import functools
import inspect
import traceback
import aiohttp
from colorama import Fore
from prompt_toolkit.validation import Validator
from questionary import prompt
from colorama import Fore
import aiohttp
import click
from ..base import app
from ..config import get_token
from .validators import Required, ClickValidator
from .validators import ClickValidator, Required
def with_http(func):
@ -105,7 +106,7 @@ def command(help: str) -> Callable[[Callable], Callable]:
return decorator
def yesno(val: str) -> Optional[bool]:
def yesno(val: str) -> bool | None:
if not val:
return None
elif isinstance(val, bool):
@ -119,11 +120,20 @@ def yesno(val: str) -> Optional[bool]:
yesno.__name__ = "yes/no"
def option(short: str, long: str, message: str = None, help: str = None,
click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None,
validator: Type[Validator] = None, required: bool = False,
default: Union[str, bool, None] = None, is_flag: bool = False, prompt: bool = True,
required_unless: Union[str, list, dict] = None) -> Callable[[Callable], Callable]:
def option(
short: str,
long: str,
message: str = None,
help: str = None,
click_type: str | Callable[[str], Any] = None,
inq_type: str = None,
validator: type[Validator] = None,
required: bool = False,
default: str | bool | None = None,
is_flag: bool = False,
prompt: bool = True,
required_unless: str | list | dict = None,
) -> Callable[[Callable], Callable]:
if not message:
message = long[2].upper() + long[3:]
@ -139,9 +149,9 @@ def option(short: str, long: str, message: str = None, help: str = None,
if not hasattr(func, "__inquirer_questions__"):
func.__inquirer_questions__ = {}
q = {
"type": (inq_type if isinstance(inq_type, str)
else ("input" if not is_flag
else "confirm")),
"type": (
inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm")
),
"name": long[2:],
"message": message,
}

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,9 +16,9 @@
from typing import Callable
import os
from packaging.version import Version, InvalidVersion
from prompt_toolkit.validation import Validator, ValidationError
from packaging.version import InvalidVersion, Version
from prompt_toolkit.document import Document
from prompt_toolkit.validation import ValidationError, Validator
import click
from ..util import spdx as spdxlib

View file

@ -1 +1 @@
from . import upload, build, login, init, logs, auth
from . import auth, build, init, login, logs, upload

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import webbrowser
import json
import webbrowser
from colorama import Fore
from yarl import URL
@ -26,12 +26,16 @@ from ..cliq import cliq
history_count: int = 10
friendly_errors = {
"server_not_found": "Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty.",
"registration_no_sso": "The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag.",
"server_not_found": (
"Registration target server not found.\n\n"
"To log in or register through maubot, you must add the server to the\n"
"homeservers section in the config. If you only want to log in,\n"
"leave the `secret` field empty."
),
"registration_no_sso": (
"The register operation is only for registering with a password.\n\n"
"To register with SSO, simply leave out the --register flag."
),
}
@ -46,26 +50,58 @@ async def list_servers(server: str, sess: aiohttp.ClientSession) -> None:
@cliq.command(help="Log into a Matrix account via the Maubot server")
@cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list")
@cliq.option("-u", "--username", help="The username to log in with",
required_unless=["list", "sso"])
@cliq.option("-p", "--password", help="The password to log in with", inq_type="password",
required_unless=["list", "sso"])
@cliq.option("-s", "--server", help="The maubot instance to log in through", default="",
required=False, prompt=False)
@click.option("-r", "--register", help="Register instead of logging in", is_flag=True,
default=False)
@click.option("-c", "--update-client", help="Instead of returning the access token, "
"create or update a client in maubot using it",
is_flag=True, default=False)
@cliq.option(
"-u", "--username", help="The username to log in with", required_unless=["list", "sso"]
)
@cliq.option(
"-p",
"--password",
help="The password to log in with",
inq_type="password",
required_unless=["list", "sso"],
)
@cliq.option(
"-s",
"--server",
help="The maubot instance to log in through",
default="",
required=False,
prompt=False,
)
@click.option(
"-r", "--register", help="Register instead of logging in", is_flag=True, default=False
)
@click.option(
"-c",
"--update-client",
help="Instead of returning the access token, " "create or update a client in maubot using it",
is_flag=True,
default=False,
)
@click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False)
@click.option("-o", "--sso", help="Use single sign-on instead of password login",
is_flag=True, default=False)
@click.option("-n", "--device-name", help="The initial e2ee device displayname (only for login)",
default="Maubot", required=False)
@click.option(
"-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False
)
@click.option(
"-n",
"--device-name",
help="The initial e2ee device displayname (only for login)",
default="Maubot",
required=False,
)
@cliq.with_authenticated_http
async def auth(homeserver: str, username: str, password: str, server: str, register: bool,
list: bool, update_client: bool, device_name: str, sso: bool,
sess: aiohttp.ClientSession) -> None:
async def auth(
homeserver: str,
username: str,
password: str,
server: str,
register: bool,
list: bool,
update_client: bool,
device_name: str,
sso: bool,
sess: aiohttp.ClientSession,
) -> None:
if list:
await list_servers(server, sess)
return
@ -88,8 +124,9 @@ async def auth(homeserver: str, username: str, password: str, server: str, regis
await print_response(resp, is_register=register)
async def wait_sso(resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession,
server: str, homeserver: str) -> None:
async def wait_sso(
resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str
) -> None:
data = await resp.json()
sso_url, reg_id = data["sso_url"], data["id"]
print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}")
@ -110,9 +147,11 @@ async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> Non
elif resp.status in (201, 202):
data = await resp.json()
action = "created" if resp.status == 201 else "updated"
print(f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}")
print(
f"{Fore.GREEN}Successfully {action} client for "
f"{Fore.CYAN}{data['id']}{Fore.GREEN} / "
f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}"
)
else:
await print_error(resp, is_register)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,26 +13,28 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Union, IO
from __future__ import annotations
from typing import IO
from io import BytesIO
import zipfile
import asyncio
import glob
import os
import zipfile
from ruamel.yaml import YAML, YAMLError
from aiohttp import ClientSession
from questionary import prompt
from colorama import Fore
from questionary import prompt
from ruamel.yaml import YAML, YAMLError
import click
from mautrix.types import SerializerError
from ...loader import PluginMeta
from ..cliq.validators import PathValidator
from ..base import app
from ..config import get_token
from ..cliq import cliq
from ..cliq.validators import PathValidator
from ..config import get_token
from .upload import upload_file
yaml = YAML()
@ -44,7 +46,7 @@ def zipdir(zip, dir):
zip.write(os.path.join(root, file))
def read_meta(path: str) -> Optional[PluginMeta]:
def read_meta(path: str) -> PluginMeta | None:
try:
with open(os.path.join(path, "maubot.yaml")) as meta_file:
try:
@ -65,7 +67,7 @@ def read_meta(path: str) -> Optional[PluginMeta]:
return meta
def read_output_path(output: str, meta: PluginMeta) -> Optional[str]:
def read_output_path(output: str, meta: PluginMeta) -> str | None:
directory = os.getcwd()
filename = f"{meta.id}-v{meta.version}.mbp"
if not output:
@ -73,18 +75,15 @@ def read_output_path(output: str, meta: PluginMeta) -> Optional[str]:
elif os.path.isdir(output):
output = os.path.join(output, filename)
elif os.path.exists(output):
override = prompt({
"type": "confirm",
"name": "override",
"message": f"{output} exists, override?"
})["override"]
q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}]
override = prompt(q)["override"]
if not override:
return None
os.remove(output)
return os.path.abspath(output)
def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None:
def write_plugin(meta: PluginMeta, output: str | IO) -> None:
with zipfile.ZipFile(output, "w") as zip:
meta_dump = BytesIO()
yaml.dump(meta.serialize(), meta_dump)
@ -104,7 +103,7 @@ def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None:
@cliq.with_authenticated_http
async def upload_plugin(output: Union[str, IO], *, server: str, sess: ClientSession) -> None:
async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None:
server, token = get_token(server)
if not token:
return
@ -115,14 +114,20 @@ async def upload_plugin(output: Union[str, IO], *, server: str, sess: ClientSess
await upload_file(sess, output, server)
@app.command(short_help="Build a maubot plugin",
help="Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file.")
@app.command(
short_help="Build a maubot plugin",
help=(
"Build a maubot plugin. First parameter is the path to root of the plugin "
"to build. You can also use --output to specify output file."
),
)
@click.argument("path", default=os.getcwd())
@click.option("-o", "--output", help="Path to output built plugin to",
type=PathValidator.click_type)
@click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True,
default=False)
@click.option(
"-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type
)
@click.option(
"-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False
)
@click.option("-s", "--server", help="Server to upload built plugin to")
def build(path: str, output: str, upload: bool, server: str) -> None:
meta = read_meta(path)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,11 +13,11 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from pkg_resources import resource_string
import os
from packaging.version import Version
from jinja2 import Template
from packaging.version import Version
from pkg_resources import resource_string
from .. import cliq
from ..cliq import SPDXValidator, VersionValidator
@ -40,26 +40,55 @@ def load_templates():
@cliq.command(help="Initialize a new maubot plugin")
@cliq.option("-n", "--name", help="The name of the project", required=True,
default=os.path.basename(os.getcwd()))
@cliq.option("-i", "--id", message="ID", required=True,
help="The maubot plugin ID (Java package name format)")
@cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)",
default="0.1.0", validator=VersionValidator, required=True)
@cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)", required=False)
@cliq.option("-c", "--config", message="Should the plugin include a config?",
help="Include a config in the plugin stub", default=False, is_flag=True)
@cliq.option(
"-n",
"--name",
help="The name of the project",
required=True,
default=os.path.basename(os.getcwd()),
)
@cliq.option(
"-i",
"--id",
message="ID",
required=True,
help="The maubot plugin ID (Java package name format)",
)
@cliq.option(
"-v",
"--version",
help="Initial version for project (PEP-440 format)",
default="0.1.0",
validator=VersionValidator,
required=True,
)
@cliq.option(
"-l",
"--license",
validator=SPDXValidator,
default="AGPL-3.0-or-later",
help="The license for the project (SPDX identifier)",
required=False,
)
@cliq.option(
"-c",
"--config",
message="Should the plugin include a config?",
help="Include a config in the plugin stub",
default=False,
is_flag=True,
)
def init(name: str, id: str, version: Version, license: str, config: bool) -> None:
load_templates()
main_class = name[0].upper() + name[1:]
meta = meta_template.render(id=id, version=str(version), license=license, config=config,
main_class=main_class)
meta = meta_template.render(
id=id, version=str(version), license=license, config=config, main_class=main_class
)
with open("maubot.yaml", "w") as file:
file.write(meta)
if license:
with open("LICENSE", "w") as file:
file.write(spdx.get(license)["text"])
file.write(spdx.get(license)["licenseText"])
if not os.path.isdir(name):
os.mkdir(name)
mod = mod_template.render(config=config, name=main_class)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -20,17 +20,39 @@ from colorama import Fore
from yarl import URL
import aiohttp
from ..config import save_config, config
from ..cliq import cliq
from ..config import config, save_config
@cliq.command(help="Log in to a Maubot instance")
@cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True)
@cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True)
@cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True)
@cliq.option("-a", "--alias", help="Alias to reference the server without typing the full URL", default="", required=False)
@cliq.option(
"-u",
"--username",
help="The username of your account",
default=os.environ.get("USER", None),
required=True,
)
@cliq.option(
"-p", "--password", help="The password to your account", inq_type="password", required=True
)
@cliq.option(
"-s",
"--server",
help="The server to log in to",
default="http://localhost:29316",
required=True,
)
@cliq.option(
"-a",
"--alias",
help="Alias to reference the server without typing the full URL",
default="",
required=False,
)
@cliq.with_http
async def login(server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession) -> None:
async def login(
server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession
) -> None:
data = {
"username": username,
"password": password,

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,14 +16,14 @@
from datetime import datetime
import asyncio
from aiohttp import ClientSession, WSMessage, WSMsgType
from colorama import Fore
from aiohttp import WSMsgType, WSMessage, ClientSession
import click
from mautrix.types import Obj
from ..config import get_token
from ..base import app
from ..config import get_token
history_count: int = 10
@ -50,7 +50,7 @@ def logs(server: str, tail: int) -> None:
def parsedate(entry: Obj) -> None:
i = entry.time.index("+")
i = entry.time.index(":", i)
entry.time = entry.time[:i] + entry.time[i + 1:]
entry.time = entry.time[:i] + entry.time[i + 1 :]
entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z")
@ -66,13 +66,16 @@ levelcolors = {
def print_entry(entry: dict) -> None:
entry = Obj(**entry)
parsedate(entry)
print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}"
.format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
level=entry.levelname,
levelcolor=levelcolors.get(entry.levelname, ""),
resetcolor=Fore.RESET,
logger=entry.name,
message=entry.msg))
print(
"{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format(
date=entry.time.strftime("%Y-%m-%d %H:%M:%S"),
level=entry.levelname,
levelcolor=levelcolors.get(entry.levelname, ""),
resetcolor=Fore.RESET,
logger=entry.name,
message=entry.msg,
)
)
if entry.exc_info:
print(entry.exc_info)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -43,8 +43,10 @@ async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> Non
async with sess.post(url, data=file, headers=headers) as resp:
if resp.status in (200, 201):
data = await resp.json()
print(f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} "
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}")
print(
f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} "
f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}"
)
else:
try:
err = await resp.json()

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Optional, Dict, Any
from __future__ import annotations
from typing import Any
import json
import os
from colorama import Fore
config: Dict[str, Any] = {
config: dict[str, Any] = {
"servers": {},
"aliases": {},
"default_server": None,
@ -27,9 +29,9 @@ config: Dict[str, Any] = {
configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config"))
def get_default_server() -> Tuple[Optional[str], Optional[str]]:
def get_default_server() -> tuple[str | None, str | None]:
try:
server: Optional[str] = config["default_server"]
server: str < None = config["default_server"]
except KeyError:
server = None
if server is None:
@ -38,7 +40,7 @@ def get_default_server() -> Tuple[Optional[str], Optional[str]]:
return server, _get_token(server)
def get_token(server: str) -> Tuple[Optional[str], Optional[str]]:
def get_token(server: str) -> tuple[str | None, str | None]:
if not server:
return get_default_server()
if server in config["aliases"]:
@ -46,14 +48,14 @@ def get_token(server: str) -> Tuple[Optional[str], Optional[str]]:
return server, _get_token(server)
def _resolve_alias(alias: str) -> Optional[str]:
def _resolve_alias(alias: str) -> str | None:
try:
return config["aliases"][alias]
except KeyError:
return None
def _get_token(server: str) -> Optional[str]:
def _get_token(server: str) -> str | None:
try:
return config["servers"][server]
except KeyError:

Binary file not shown.

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Optional
import zipfile
import pkg_resources
import json
from __future__ import annotations
spdx_list: Optional[Dict[str, Dict[str, str]]] = None
import json
import zipfile
import pkg_resources
spdx_list: dict[str, dict[str, str]] | None = None
def load() -> None:
@ -31,7 +33,7 @@ def load() -> None:
spdx_list = json.load(file)
def get(id: str) -> Dict[str, str]:
def get(id: str) -> dict[str, str]:
if not spdx_list:
load()
return spdx_list[id.lower()]

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,32 +13,46 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
import asyncio
import logging
from aiohttp import ClientSession
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership,
StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter,
PresenceState, StateFilter, DeviceID)
from mautrix.client import InternalEventType
from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
DeviceID,
EventFilter,
EventType,
Filter,
FilterID,
Membership,
PresenceState,
RoomEventFilter,
RoomFilter,
StateEvent,
StateFilter,
StrippedStateEvent,
SyncToken,
UserID,
)
from .lib.store_proxy import SyncStoreProxy
from .db import DBClient
from .lib.store_proxy import SyncStoreProxy
from .matrix import MaubotMatrixClient
try:
from mautrix.crypto import OlmMachine, StateStore as CryptoStateStore, PgCryptoStore
from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
from mautrix.util.async_db import Database as AsyncDatabase
class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
pass
crypto_import_error = None
except ImportError as e:
OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
@ -46,8 +60,8 @@ except ImportError as e:
crypto_import_error = e
if TYPE_CHECKING:
from .instance import PluginInstance
from .config import Config
from .instance import PluginInstance
log = logging.getLogger("maubot.client")
@ -55,20 +69,20 @@ log = logging.getLogger("maubot.client")
class Client:
log: logging.Logger = None
loop: asyncio.AbstractEventLoop = None
cache: Dict[UserID, 'Client'] = {}
cache: dict[UserID, Client] = {}
http_client: ClientSession = None
global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore()
crypto_db: Optional['AsyncDatabase'] = None
global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
crypto_db: AsyncDatabase | None = None
references: Set['PluginInstance']
references: set[PluginInstance]
db_instance: DBClient
client: MaubotMatrixClient
crypto: Optional['OlmMachine']
crypto_store: Optional['PgCryptoStore']
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
remote_displayname: Optional[str]
remote_avatar_url: Optional[ContentURI]
remote_displayname: str | None
remote_avatar_url: ContentURI | None
def __init__(self, db_instance: DBClient) -> None:
self.db_instance = db_instance
@ -79,11 +93,17 @@ class Client:
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver,
token=self.access_token, client_session=self.http_client,
log=self.log, loop=self.loop, device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store)
self.client = MaubotMatrixClient(
mxid=self.id,
base_url=self.homeserver,
token=self.access_token,
client_session=self.http_client,
log=self.log,
loop=self.loop,
device_id=self.device_id,
sync_store=SyncStoreProxy(self.db_instance),
state_store=self.global_state_store,
)
if self.enable_crypto:
self._prepare_crypto()
else:
@ -104,8 +124,10 @@ class Client:
return False
elif not OlmMachine:
global crypto_import_error
self.log.warning("Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error)
self.log.warning(
"Client has device ID, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
@ -115,8 +137,9 @@ class Client:
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(account_id=self.id, pickle_key="mau.crypto",
db=self.crypto_db)
self.crypto_store = PgCryptoStore(
account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
)
self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
@ -133,13 +156,13 @@ class Client:
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]:
async def handler(data: Dict[str, Any]) -> None:
def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
async def handler(data: dict[str, Any]) -> None:
self.sync_ok = ok
return handler
async def start(self, try_n: Optional[int] = 0) -> None:
async def start(self, try_n: int | None = 0) -> None:
try:
if try_n > 0:
await asyncio.sleep(try_n * 10)
@ -152,15 +175,16 @@ class Client:
await self.crypto_store.open()
crypto_device_id = await self.crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != self.device_id:
self.log.warning("Mismatching device ID in crypto store and main database, "
"resetting encryption")
self.log.warning(
"Mismatching device ID in crypto store and main database, " "resetting encryption"
)
await self.crypto_store.delete()
crypto_device_id = None
await self.crypto.load()
if not crypto_device_id:
await self.crypto_store.put_device_id(self.device_id)
async def _start(self, try_n: Optional[int] = 0) -> None:
async def _start(self, try_n: int | None = 0) -> None:
if not self.enabled:
self.log.debug("Not starting disabled client")
return
@ -179,8 +203,9 @@ class Client:
self.log.exception("Failed to get /account/whoami, disabling client")
self.db_instance.enabled = False
else:
self.log.warning(f"Failed to get /account/whoami, "
f"retrying in {(try_n + 1) * 10}s: {e}")
self.log.warning(
f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
)
_ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
return
if whoami.user_id != self.id:
@ -188,25 +213,30 @@ class Client:
self.db_instance.enabled = False
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(f"Device ID mismatch: expected {self.device_id}, "
f"but got {whoami.device_id}")
self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
self.db_instance.enabled = False
return
if not self.filter_id:
self.db_instance.edit(filter_id=await self.client.create_filter(Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
self.db_instance.edit(
filter_id=await self.client.create_filter(
Filter(
room=RoomFilter(
timeline=RoomEventFilter(
limit=50,
lazy_load_members=True,
),
state=StateFilter(
lazy_load_members=True,
),
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)
),
presence=EventFilter(
not_types=[EventType.PRESENCE],
),
)))
)
)
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
@ -258,8 +288,9 @@ class Client:
"homeserver": self.homeserver,
"access_token": self.access_token,
"device_id": self.device_id,
"fingerprint": (self.crypto.account.fingerprint if self.crypto and self.crypto.account
else None),
"fingerprint": (
self.crypto.account.fingerprint if self.crypto and self.crypto.account else None
),
"enabled": self.enabled,
"started": self.started,
"sync": self.sync,
@ -274,7 +305,7 @@ class Client:
}
@classmethod
def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']:
def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
try:
return cls.cache[user_id]
except KeyError:
@ -284,7 +315,7 @@ class Client:
return Client(db_instance)
@classmethod
def all(cls) -> Iterable['Client']:
def all(cls) -> Iterable[Client]:
return (cls.get(user.id, user) for user in DBClient.all())
async def _handle_tombstone(self, evt: StateEvent) -> None:
@ -324,8 +355,12 @@ class Client:
else:
await self._update_remote_profile()
async def update_access_details(self, access_token: Optional[str], homeserver: Optional[str],
device_id: Optional[str] = None) -> None:
async def update_access_details(
self,
access_token: str | None,
homeserver: str | None,
device_id: str | None = None,
) -> None:
if not access_token and not homeserver:
return
if device_id is None:
@ -338,10 +373,16 @@ class Client:
and device_id == self.device_id
):
return
new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver,
token=access_token or self.access_token, loop=self.loop,
device_id=device_id, client_session=self.http_client,
log=self.log, state_store=self.global_state_store)
new_client = MaubotMatrixClient(
mxid=self.id,
base_url=homeserver or self.homeserver,
token=access_token or self.access_token,
loop=self.loop,
device_id=device_id,
client_session=self.http_client,
log=self.log,
state_store=self.global_state_store,
)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
@ -455,7 +496,7 @@ class Client:
# endregion
def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
Client.http_client = ClientSession(loop=loop)
Client.loop = loop

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,9 +14,10 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
import random
import string
import bcrypt
import re
import string
import bcrypt
from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper
@ -64,8 +65,9 @@ class Config(BaseFileConfig):
if password and not bcrypt_regex.match(password):
if password == "password":
password = self._new_token()
base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"),
bcrypt.gensalt()).decode("utf-8")
base["admins"][username] = bcrypt.hashpw(
password.encode("utf-8"), bcrypt.gensalt()
).decode("utf-8")
copy("api_features.login")
copy("api_features.plugin")
copy("api_features.plugin_upload")

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,13 +17,13 @@ from typing import Iterable, Optional
import logging
import sys
from sqlalchemy import Column, String, Boolean, ForeignKey, Text
from sqlalchemy import Boolean, Column, ForeignKey, String, Text
from sqlalchemy.engine.base import Engine
import sqlalchemy as sql
from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI
from mautrix.util.db import Base
from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
from mautrix.util.db import Base
from .config import Config
@ -34,17 +34,19 @@ class DBPlugin(Base):
id: str = Column(String(255), primary_key=True)
type: str = Column(String(255), nullable=False)
enabled: bool = Column(Boolean, nullable=False, default=False)
primary_user: UserID = Column(String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False)
config: str = Column(Text, nullable=False, default='')
primary_user: UserID = Column(
String(255),
ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
nullable=False,
)
config: str = Column(Text, nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBPlugin']:
def all(cls) -> Iterable["DBPlugin"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBPlugin']:
def get(cls, id: str) -> Optional["DBPlugin"]:
return cls._select_one_or_none(cls.c.id == id)
@ -68,11 +70,11 @@ class DBClient(Base):
avatar_url: ContentURI = Column(String(255), nullable=False, default="")
@classmethod
def all(cls) -> Iterable['DBClient']:
def all(cls) -> Iterable["DBClient"]:
return cls._select_all()
@classmethod
def get(cls, id: str) -> Optional['DBClient']:
def get(cls, id: str) -> Optional["DBClient"]:
return cls._select_one_or_none(cls.c.id == id)
@ -87,15 +89,20 @@ def init(config: Config) -> Engine:
log = logging.getLogger("maubot.db")
if db.has_table("client") and db.has_table("plugin"):
log.warning("alembic_version table not found, but client and plugin tables found. "
"Assuming pre-Alembic database and inserting version.")
db.execute("CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");")
log.warning(
"alembic_version table not found, but client and plugin tables found. "
"Assuming pre-Alembic database and inserting version."
)
db.execute(
"CREATE TABLE IF NOT EXISTS alembic_version ("
" version_num VARCHAR(32) PRIMARY KEY"
");"
)
db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
else:
log.critical("alembic_version table not found. "
"Did you forget to `alembic upgrade head`?")
log.critical(
"alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
)
sys.exit(10)
return db

View file

@ -1 +1 @@
from . import event, command, web
from . import command, event, web

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,29 +13,46 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List,
Dict, Tuple, Set, Iterable)
from typing import (
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
NewType,
Optional,
Pattern,
Sequence,
Set,
Tuple,
Union,
)
from abc import ABC, abstractmethod
import asyncio
import functools
import inspect
import re
from mautrix.types import MessageType, EventType
from mautrix.types import EventType, MessageType
from ..matrix import MaubotMessageEvent
from . import event
PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]]
AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool],
Callable[[Any, str], bool]]
CommandHandlerFunc = NewType("CommandHandlerFunc",
Callable[[MaubotMessageEvent, Any], Awaitable[Any]])
CommandHandlerDecorator = NewType("CommandHandlerDecorator",
Callable[[Union['CommandHandler', CommandHandlerFunc]],
'CommandHandler'])
PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator",
Callable[[CommandHandlerFunc], CommandHandlerFunc])
AliasesType = Union[
List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool]
]
CommandHandlerFunc = NewType(
"CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]]
)
CommandHandlerDecorator = NewType(
"CommandHandlerDecorator",
Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"],
)
PassiveCommandHandlerDecorator = NewType(
"PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc]
)
def _split_in_two(val: str, split_by: str) -> List[str]:
@ -67,15 +84,26 @@ class CommandHandler:
return self.__bound_copies__[instance]
except KeyError:
new_ch = type(self)(self.__mb_func__)
keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match",
"require_subcommand", "arg_fallthrough", "event_handler", "event_type",
"msgtypes"]
keys = [
"parent",
"subcommands",
"arguments",
"help",
"get_name",
"is_command_match",
"require_subcommand",
"arg_fallthrough",
"event_handler",
"event_type",
"msgtypes",
]
for key in keys:
key = f"__mb_{key}__"
setattr(new_ch, key, getattr(self, key))
new_ch.__bound_instance__ = instance
new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype)
for subcmd in self.__mb_subcommands__]
new_ch.__mb_subcommands__ = [
subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__
]
self.__bound_copies__[instance] = new_ch
return new_ch
@ -83,8 +111,13 @@ class CommandHandler:
def __command_match_unset(self, val: str) -> bool:
raise NotImplementedError("Hmm")
async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None,
remaining_val: str = None) -> Any:
async def __call__(
self,
evt: MaubotMessageEvent,
*,
_existing_args: Dict[str, Any] = None,
remaining_val: str = None,
) -> Any:
if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__:
return
if remaining_val is None:
@ -120,21 +153,25 @@ class CommandHandler:
return await self.__mb_func__(self.__bound_instance__, evt, **call_args)
return await self.__mb_func__(evt, **call_args)
async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, Any]:
async def __call_subcommand__(
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, Any]:
command, remaining_val = _split_in_two(remaining_val.strip(), " ")
for subcommand in self.__mb_subcommands__:
if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command):
return True, await subcommand(evt, _existing_args=call_args,
remaining_val=remaining_val)
return True, await subcommand(
evt, _existing_args=call_args, remaining_val=remaining_val
)
return False, None
async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any],
remaining_val: str) -> Tuple[bool, str]:
async def __parse_args__(
self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str
) -> Tuple[bool, str]:
for arg in self.__mb_arguments__:
try:
remaining_val, call_args[arg.name] = arg.match(remaining_val.strip(), evt=evt,
instance=self.__bound_instance__)
remaining_val, call_args[arg.name] = arg.match(
remaining_val.strip(), evt=evt, instance=self.__bound_instance__
)
if arg.required and call_args[arg.name] is None:
raise ValueError("Argument required")
except ArgumentSyntaxError as e:
@ -155,8 +192,9 @@ class CommandHandler:
@property
def __mb_usage_args__(self) -> str:
arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]"
for arg in self.__mb_arguments__)
arg_usage = " ".join(
f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__
)
if self.__mb_subcommands__ and self.__mb_arg_fallthrough__:
arg_usage += " " + self.__mb_usage_subcommand__
return arg_usage
@ -172,15 +210,19 @@ class CommandHandler:
@property
def __mb_prefix__(self) -> str:
if self.__mb_parent__:
return (f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} "
f"{self.__mb_name__}")
return (
f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} "
f"{self.__mb_name__}"
)
return f"!{self.__mb_name__}"
@property
def __mb_usage_inline__(self) -> str:
if not self.__mb_arg_fallthrough__:
return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}")
return (
f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n"
f"* {self.__mb_name__} {self.__mb_usage_subcommand__}"
)
return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}"
@property
@ -192,8 +234,10 @@ class CommandHandler:
if not self.__mb_arg_fallthrough__:
if not self.__mb_arguments__:
return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]"
return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}")
return (
f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}"
)
return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}"
@property
@ -202,14 +246,25 @@ class CommandHandler:
return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}"
return self.__mb_usage_without_subcommands__
def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
required_subcommand: bool = True, arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def subcommand(
self,
name: PrefixType = None,
*,
help: str = None,
aliases: AliasesType = None,
required_subcommand: bool = True,
arg_fallthrough: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
new(name, help=help, aliases=aliases, require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough)(func)
new(
name,
help=help,
aliases=aliases,
require_subcommand=required_subcommand,
arg_fallthrough=arg_fallthrough,
)(func)
func.__mb_parent__ = self
func.__mb_event_handler__ = False
self.__mb_subcommands__.append(func)
@ -218,10 +273,17 @@ class CommandHandler:
return decorator
def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True, arg_fallthrough: bool = True,
must_consume_args: bool = True) -> CommandHandlerDecorator:
def new(
name: PrefixType = None,
*,
help: str = None,
aliases: AliasesType = None,
event_type: EventType = EventType.ROOM_MESSAGE,
msgtypes: Iterable[MessageType] = None,
require_subcommand: bool = True,
arg_fallthrough: bool = True,
must_consume_args: bool = True,
) -> CommandHandlerDecorator:
def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler:
if not isinstance(func, CommandHandler):
func = CommandHandler(func)
@ -242,8 +304,9 @@ def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = Non
else:
func.__mb_is_command_match__ = aliases
elif isinstance(aliases, (list, set, tuple)):
func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_get_name__(self)
or val in aliases)
func.__mb_is_command_match__ = lambda self, val: (
val == func.__mb_get_name__(self) or val in aliases
)
else:
func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self)
# Decorators are executed last to first, so we reverse the argument list.
@ -267,8 +330,9 @@ class ArgumentSyntaxError(ValueError):
class Argument(ABC):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False) -> None:
def __init__(
self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False
) -> None:
self.name = name
self.label = label or name
self.required = required
@ -286,8 +350,15 @@ class Argument(ABC):
class RegexArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matches: str = None) -> None:
def __init__(
self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matches: str = None,
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
matches = f"^{matches}" if self.pass_raw else f"^{matches}$"
self.regex = re.compile(matches)
@ -298,14 +369,23 @@ class RegexArgument(Argument):
val = re.split(r"\s", val, 1)[0]
match = self.regex.match(val)
if match:
return (orig_val[:match.start()] + orig_val[match.end():],
match.groups() or val[match.start():match.end()])
return (
orig_val[: match.start()] + orig_val[match.end() :],
match.groups() or val[match.start() : match.end()],
)
return orig_val, None
class CustomArgument(Argument):
def __init__(self, name: str, label: str = None, *, required: bool = False,
pass_raw: bool = False, matcher: Callable[[str], Any]) -> None:
def __init__(
self,
name: str,
label: str = None,
*,
required: bool = False,
pass_raw: bool = False,
matcher: Callable[[str], Any],
) -> None:
super().__init__(name, label, required=required, pass_raw=pass_raw)
self.matcher = matcher
@ -316,7 +396,7 @@ class CustomArgument(Argument):
val = re.split(r"\s", val, 1)[0]
res = self.matcher(val)
if res is not None:
return orig_val[len(val):], res
return orig_val[len(val) :], res
return orig_val, None
@ -325,12 +405,18 @@ class SimpleArgument(Argument):
if self.pass_raw:
return "", val
res = re.split(r"\s", val, 1)[0]
return val[len(res):], res
return val[len(res) :], res
def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False
) -> CommandHandlerDecorator:
def argument(
name: str,
label: str = None,
*,
required: bool = True,
matches: Optional[str] = None,
parser: Optional[Callable[[str], Any]] = None,
pass_raw: bool = False,
) -> CommandHandlerDecorator:
if matches:
return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw)
elif parser:
@ -339,11 +425,17 @@ def argument(name: str, label: str = None, *, required: bool = True, matches: Op
return SimpleArgument(name, label, required=required, pass_raw=pass_raw)
def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False,
case_insensitive: bool = False, multiline: bool = False, dot_all: bool = False
) -> PassiveCommandHandlerDecorator:
def passive(
regex: Union[str, Pattern],
*,
msgtypes: Sequence[MessageType] = (MessageType.TEXT,),
field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body,
event_type: EventType = EventType.ROOM_MESSAGE,
multiple: bool = False,
case_insensitive: bool = False,
multiline: bool = False,
dot_all: bool = False,
) -> PassiveCommandHandlerDecorator:
if not isinstance(regex, Pattern):
flags = re.RegexFlag.UNICODE
if case_insensitive:
@ -372,12 +464,14 @@ def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (Me
return
data = field(evt)
if multiple:
val = [(data[match.pos:match.endpos], *match.groups())
for match in regex.finditer(data)]
val = [
(data[match.pos : match.endpos], *match.groups())
for match in regex.finditer(data)
]
else:
match = regex.search(data)
if match:
val = (data[match.pos:match.endpos], *match.groups())
val = (data[match.pos : match.endpos], *match.groups())
else:
val = None
if val:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,16 +13,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Union, NewType
from __future__ import annotations
from typing import Callable, NewType
from mautrix.types import EventType
from mautrix.client import EventHandler, InternalEventType
from mautrix.types import EventType
EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler])
def on(var: Union[EventType, InternalEventType, EventHandler]
) -> Union[EventHandlerDecorator, EventHandler]:
def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler:
def decorator(func: EventHandler) -> EventHandler:
func.__mb_event_handler__ = True
if isinstance(var, (EventType, InternalEventType)):

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,9 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Any, Awaitable
from typing import Any, Awaitable, Callable
from aiohttp import web, hdrs
from aiohttp import hdrs, web
WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]]
WebHandlerDecorator = Callable[[WebHandler], WebHandler]

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,24 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Optional, Iterable, TYPE_CHECKING
from asyncio import AbstractEventLoop
import os.path
import logging
import io
from __future__ import annotations
from typing import TYPE_CHECKING, Iterable
from asyncio import AbstractEventLoop
import io
import logging
import os.path
from ruamel.yaml.comments import CommentedMap
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.types import UserID
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .db import DBPlugin
from .config import Config
from .client import Client
from .config import Config
from .db import DBPlugin
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
@ -43,23 +45,23 @@ yaml.width = 200
class PluginInstance:
webserver: 'MaubotServer' = None
webserver: MaubotServer = None
mb_config: Config = None
loop: AbstractEventLoop = None
cache: Dict[str, 'PluginInstance'] = {}
plugin_directories: List[str] = []
cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = []
log: logging.Logger
loader: PluginLoader
client: Client
plugin: Plugin
config: BaseProxyConfig
base_cfg: Optional[RecursiveDict[CommentedMap]]
base_cfg_str: Optional[str]
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine
inst_db_tables: Dict[str, sql.Table]
inst_webapp: Optional['PluginWebApp']
inst_webapp_url: Optional[str]
inst_db_tables: dict[str, sql.Table]
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
started: bool
def __init__(self, db_instance: DBPlugin):
@ -87,11 +89,12 @@ class PluginInstance:
"primary_user": self.primary_user,
"config": self.db_instance.config,
"base_config": self.base_cfg_str,
"database": (self.inst_db is not None
and self.mb_config["api_features.instance_database"]),
"database": (
self.inst_db is not None and self.mb_config["api_features.instance_database"]
),
}
def get_db_tables(self) -> Dict[str, sql.Table]:
def get_db_tables(self) -> dict[str, sql.Table]:
if not self.inst_db_tables:
metadata = sql.MetaData()
metadata.reflect(self.inst_db)
@ -147,7 +150,8 @@ class PluginInstance:
self.inst_db.dispose()
ZippedPluginLoader.trash(
os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted")
reason="deleted",
)
if self.inst_webapp:
self.disable_webapp()
@ -194,13 +198,23 @@ class PluginInstance:
if self.base_cfg:
base_cfg_func = self.base_cfg.clone
else:
def base_cfg_func() -> None:
return None
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client,
instance_id=self.id, log=self.log, config=self.config,
database=self.inst_db, loader=self.loader, webapp=self.inst_webapp,
webapp_url=self.inst_webapp_url)
self.plugin = cls(
client=self.client.client,
loop=self.loop,
http=self.client.http_client,
instance_id=self.id,
log=self.log,
config=self.config,
database=self.inst_db,
loader=self.loader,
webapp=self.inst_webapp,
webapp_url=self.inst_webapp_url,
)
try:
await self.plugin.internal_start()
except Exception:
@ -209,8 +223,10 @@ class PluginInstance:
return
self.started = True
self.inst_db_tables = None
self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"with user {self.client.id}")
self.log.info(
f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} "
f"with user {self.client.id}"
)
async def stop(self) -> None:
if not self.started:
@ -226,8 +242,7 @@ class PluginInstance:
self.inst_db_tables = None
@classmethod
def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None
) -> Optional['PluginInstance']:
def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
try:
return cls.cache[instance_id]
except KeyError:
@ -237,7 +252,7 @@ class PluginInstance:
return PluginInstance(db_instance)
@classmethod
def all(cls) -> Iterable['PluginInstance']:
def all(cls) -> Iterable[PluginInstance]:
return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
def update_id(self, new_id: str) -> None:
@ -317,8 +332,9 @@ class PluginInstance:
# endregion
def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
def init(
config: Config, webserver: MaubotServer, loop: AbstractEventLoop
) -> Iterable[PluginInstance]:
PluginInstance.mb_config = config
PluginInstance.loop = loop
PluginInstance.webserver = webserver

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2020 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,8 +13,13 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR,
MXID_COLOR, RESET)
from mautrix.util.logging.color import (
MAU_COLOR,
MXID_COLOR,
PREFIX,
RESET,
ColorFormatter as BaseColorFormatter,
)
INST_COLOR = PREFIX + "35m" # magenta
LOADER_COLOR = PREFIX + "36m" # blue

View file

@ -1,4 +1,5 @@
from typing import Callable, Awaitable, Generator, Any
from typing import Any, Awaitable, Callable, Generator
class FutureAwaitable:
def __init__(self, func: Callable[[], Awaitable[None]]) -> None:
@ -6,4 +7,3 @@ class FutureAwaitable:
def __await__(self) -> Generator[Any, None, None]:
return self._func().__await__()

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by

View file

@ -18,26 +18,28 @@ used by the builtin import mechanism for sys.path items that are paths
to Zip archives.
"""
from importlib import _bootstrap_external
from importlib import _bootstrap # for _verbose_message
import _imp # for check_hash_based_pycs
import _io # for open
from importlib import _bootstrap_external
import marshal # for loads
import sys # for modules
import time # for mktime
__all__ = ['ZipImportError', 'zipimporter']
import _imp # for check_hash_based_pycs
import _io # for open
__all__ = ["ZipImportError", "zipimporter"]
def _unpack_uint32(data):
"""Convert 4 bytes in little-endian to an integer."""
assert len(data) == 4
return int.from_bytes(data, 'little')
return int.from_bytes(data, "little")
def _unpack_uint16(data):
"""Convert 2 bytes in little-endian to an integer."""
assert len(data) == 2
return int.from_bytes(data, 'little')
return int.from_bytes(data, "little")
path_sep = _bootstrap_external.path_sep
@ -47,15 +49,17 @@ alt_path_sep = _bootstrap_external.path_separators[1:]
class ZipImportError(ImportError):
pass
# _read_directory() cache
_zip_directory_cache = {}
_module_type = type(sys)
END_CENTRAL_DIR_SIZE = 22
STRING_END_ARCHIVE = b'PK\x05\x06'
STRING_END_ARCHIVE = b"PK\x05\x06"
MAX_COMMENT_LEN = (1 << 16) - 1
class zipimporter:
"""zipimporter(archivepath) -> zipimporter object
@ -77,9 +81,10 @@ class zipimporter:
def __init__(self, path):
if not isinstance(path, str):
import os
path = os.fsdecode(path)
if not path:
raise ZipImportError('archive path is empty', path=path)
raise ZipImportError("archive path is empty", path=path)
if alt_path_sep:
path = path.replace(alt_path_sep, path_sep)
@ -92,14 +97,14 @@ class zipimporter:
# Back up one path element.
dirname, basename = _bootstrap_external._path_split(path)
if dirname == path:
raise ZipImportError('not a Zip file', path=path)
raise ZipImportError("not a Zip file", path=path)
path = dirname
prefix.append(basename)
else:
# it exists
if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG
# it's a not file
raise ZipImportError('not a Zip file', path=path)
raise ZipImportError("not a Zip file", path=path)
break
try:
@ -154,11 +159,10 @@ class zipimporter:
# This is possibly a portion of a namespace
# package. Return the string representing its path,
# without a trailing separator.
return None, [f'{self.archive}{path_sep}{modpath}']
return None, [f"{self.archive}{path_sep}{modpath}"]
return None, []
# Check whether we can satisfy the import of the module named by
# 'fullname'. Return self if we can, None if we can't.
def find_module(self, fullname, path=None):
@ -172,7 +176,6 @@ class zipimporter:
"""
return self.find_loader(fullname, path)[0]
def get_code(self, fullname):
"""get_code(fullname) -> code object.
@ -182,7 +185,6 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname)
return code
def get_data(self, pathname):
"""get_data(pathname) -> string with file data.
@ -194,15 +196,14 @@ class zipimporter:
key = pathname
if pathname.startswith(self.archive + path_sep):
key = pathname[len(self.archive + path_sep):]
key = pathname[len(self.archive + path_sep) :]
try:
toc_entry = self._files[key]
except KeyError:
raise OSError(0, '', key)
raise OSError(0, "", key)
return _get_data(self.archive, toc_entry)
# Return a string matching __file__ for the named module
def get_filename(self, fullname):
"""get_filename(fullname) -> filename string.
@ -214,7 +215,6 @@ class zipimporter:
code, ispackage, modpath = _get_module_code(self, fullname)
return modpath
def get_source(self, fullname):
"""get_source(fullname) -> source string.
@ -228,9 +228,9 @@ class zipimporter:
path = _get_module_path(self, fullname)
if mi:
fullpath = _bootstrap_external._path_join(path, '__init__.py')
fullpath = _bootstrap_external._path_join(path, "__init__.py")
else:
fullpath = f'{path}.py'
fullpath = f"{path}.py"
try:
toc_entry = self._files[fullpath]
@ -239,7 +239,6 @@ class zipimporter:
return None
return _get_data(self.archive, toc_entry).decode()
# Return a bool signifying whether the module is a package or not.
def is_package(self, fullname):
"""is_package(fullname) -> bool.
@ -252,7 +251,6 @@ class zipimporter:
raise ZipImportError(f"can't find module {fullname!r}", name=fullname)
return mi
# Load and return the module named by 'fullname'.
def load_module(self, fullname):
"""load_module(fullname) -> module.
@ -276,7 +274,7 @@ class zipimporter:
fullpath = _bootstrap_external._path_join(self.archive, path)
mod.__path__ = [fullpath]
if not hasattr(mod, '__builtins__'):
if not hasattr(mod, "__builtins__"):
mod.__builtins__ = __builtins__
_bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath)
exec(code, mod.__dict__)
@ -287,11 +285,10 @@ class zipimporter:
try:
mod = sys.modules[fullname]
except KeyError:
raise ImportError(f'Loaded module {fullname!r} not found in sys.modules')
_bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath)
raise ImportError(f"Loaded module {fullname!r} not found in sys.modules")
_bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath)
return mod
def get_resource_reader(self, fullname):
"""Return the ResourceReader for a package in a zip file.
@ -305,11 +302,11 @@ class zipimporter:
return None
if not _ZipImportResourceReader._registered:
from importlib.abc import ResourceReader
ResourceReader.register(_ZipImportResourceReader)
_ZipImportResourceReader._registered = True
return _ZipImportResourceReader(self, fullname)
def __repr__(self):
return f'<zipimporter object "{self.archive}{path_sep}{self.prefix}">'
@ -320,16 +317,17 @@ class zipimporter:
# are swapped by initzipimport() if we run in optimized mode. Also,
# '/' is replaced by path_sep there.
_zip_searchorder = (
(path_sep + '__init__.pyc', True, True),
(path_sep + '__init__.py', False, True),
('.pyc', True, False),
('.py', False, False),
(path_sep + "__init__.pyc", True, True),
(path_sep + "__init__.py", False, True),
(".pyc", True, False),
(".py", False, False),
)
# Given a module name, return the potential file path in the
# archive (without extension).
def _get_module_path(self, fullname):
return self.prefix + fullname.rpartition('.')[2]
return self.prefix + fullname.rpartition(".")[2]
# Does this path represent a directory?
def _is_dir(self, path):
@ -340,6 +338,7 @@ def _is_dir(self, path):
# If dirpath is present in self._files, we have a directory.
return dirpath in self._files
# Return some information about a module.
def _get_module_info(self, fullname):
path = _get_module_path(self, fullname)
@ -374,7 +373,7 @@ def _get_module_info(self, fullname):
# data_size and file_offset are 0.
def _read_directory(archive):
try:
fp = _io.open(archive, 'rb')
fp = _io.open(archive, "rb")
except OSError:
raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive)
@ -394,36 +393,33 @@ def _read_directory(archive):
fp.seek(0, 2)
file_size = fp.tell()
except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}",
path=archive)
max_comment_start = max(file_size - MAX_COMMENT_LEN -
END_CENTRAL_DIR_SIZE, 0)
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0)
try:
fp.seek(max_comment_start)
data = fp.read()
except OSError:
raise ZipImportError(f"can't read Zip file: {archive!r}",
path=archive)
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
pos = data.rfind(STRING_END_ARCHIVE)
if pos < 0:
raise ZipImportError(f'not a Zip file: {archive!r}',
path=archive)
buffer = data[pos:pos+END_CENTRAL_DIR_SIZE]
raise ZipImportError(f"not a Zip file: {archive!r}", path=archive)
buffer = data[pos : pos + END_CENTRAL_DIR_SIZE]
if len(buffer) != END_CENTRAL_DIR_SIZE:
raise ZipImportError(f"corrupt Zip file: {archive!r}",
path=archive)
raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive)
header_position = file_size - len(data) + pos
header_size = _unpack_uint32(buffer[12:16])
header_offset = _unpack_uint32(buffer[16:20])
if header_position < header_size:
raise ZipImportError(f'bad central directory size: {archive!r}', path=archive)
raise ZipImportError(f"bad central directory size: {archive!r}", path=archive)
if header_position < header_offset:
raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive)
raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive)
header_position -= header_size
arc_offset = header_position - header_offset
if arc_offset < 0:
raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive)
raise ZipImportError(
f"bad central directory size or offset: {archive!r}", path=archive
)
files = {}
# Start of Central Directory
@ -435,12 +431,12 @@ def _read_directory(archive):
while True:
buffer = fp.read(46)
if len(buffer) < 4:
raise EOFError('EOF read where not expected')
raise EOFError("EOF read where not expected")
# Start of file header
if buffer[:4] != b'PK\x01\x02':
break # Bad: Central Dir File Header
if buffer[:4] != b"PK\x01\x02":
break # Bad: Central Dir File Header
if len(buffer) != 46:
raise EOFError('EOF read where not expected')
raise EOFError("EOF read where not expected")
flags = _unpack_uint16(buffer[8:10])
compress = _unpack_uint16(buffer[10:12])
time = _unpack_uint16(buffer[12:14])
@ -454,7 +450,7 @@ def _read_directory(archive):
file_offset = _unpack_uint32(buffer[42:46])
header_size = name_size + extra_size + comment_size
if file_offset > header_offset:
raise ZipImportError(f'bad local header offset: {archive!r}', path=archive)
raise ZipImportError(f"bad local header offset: {archive!r}", path=archive)
file_offset += arc_offset
try:
@ -478,18 +474,19 @@ def _read_directory(archive):
else:
# Historical ZIP filename encoding
try:
name = name.decode('ascii')
name = name.decode("ascii")
except UnicodeDecodeError:
name = name.decode('latin1').translate(cp437_table)
name = name.decode("latin1").translate(cp437_table)
name = name.replace('/', path_sep)
name = name.replace("/", path_sep)
path = _bootstrap_external._path_join(archive, name)
t = (path, compress, data_size, file_size, file_offset, time, date, crc)
files[name] = t
count += 1
_bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive)
_bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive)
return files
# During bootstrap, we may need to load the encodings
# package from a ZIP file. But the cp437 encoding is implemented
# in Python in the encodings package.
@ -498,31 +495,31 @@ def _read_directory(archive):
# the cp437 encoding.
cp437_table = (
# ASCII part, 8 rows x 16 chars
'\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f'
'\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f'
' !"#$%&\'()*+,-./'
'0123456789:;<=>?'
'@ABCDEFGHIJKLMNO'
'PQRSTUVWXYZ[\\]^_'
'`abcdefghijklmno'
'pqrstuvwxyz{|}~\x7f'
"\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f"
"\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f"
" !\"#$%&'()*+,-./"
"0123456789:;<=>?"
"@ABCDEFGHIJKLMNO"
"PQRSTUVWXYZ[\\]^_"
"`abcdefghijklmno"
"pqrstuvwxyz{|}~\x7f"
# non-ASCII part, 16 rows x 8 chars
'\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7'
'\xea\xeb\xe8\xef\xee\xec\xc4\xc5'
'\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9'
'\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192'
'\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba'
'\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb'
'\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556'
'\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510'
'\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f'
'\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567'
'\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b'
'\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580'
'\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4'
'\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229'
'\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248'
'\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0'
"\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7"
"\xea\xeb\xe8\xef\xee\xec\xc4\xc5"
"\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9"
"\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192"
"\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba"
"\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb"
"\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556"
"\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510"
"\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f"
"\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567"
"\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b"
"\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580"
"\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4"
"\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229"
"\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248"
"\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0"
)
_importing_zlib = False
@ -535,28 +532,29 @@ def _get_decompress_func():
if _importing_zlib:
# Someone has a zlib.py[co] in their Zip file
# let's avoid a stack overflow.
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
raise ZipImportError("can't decompress data; zlib not available")
_importing_zlib = True
try:
from zlib import decompress
except Exception:
_bootstrap._verbose_message('zipimport: zlib UNAVAILABLE')
_bootstrap._verbose_message("zipimport: zlib UNAVAILABLE")
raise ZipImportError("can't decompress data; zlib not available")
finally:
_importing_zlib = False
_bootstrap._verbose_message('zipimport: zlib available')
_bootstrap._verbose_message("zipimport: zlib available")
return decompress
# Given a path to a Zip file and a toc_entry, return the (uncompressed) data.
def _get_data(archive, toc_entry):
datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry
if data_size < 0:
raise ZipImportError('negative data size')
raise ZipImportError("negative data size")
with _io.open(archive, 'rb') as fp:
with _io.open(archive, "rb") as fp:
# Check to make sure the local file header is correct
try:
fp.seek(file_offset)
@ -564,11 +562,11 @@ def _get_data(archive, toc_entry):
raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive)
buffer = fp.read(30)
if len(buffer) != 30:
raise EOFError('EOF read where not expected')
raise EOFError("EOF read where not expected")
if buffer[:4] != b'PK\x03\x04':
if buffer[:4] != b"PK\x03\x04":
# Bad: Local File Header
raise ZipImportError(f'bad local file header: {archive!r}', path=archive)
raise ZipImportError(f"bad local file header: {archive!r}", path=archive)
name_size = _unpack_uint16(buffer[26:28])
extra_size = _unpack_uint16(buffer[28:30])
@ -601,16 +599,17 @@ def _eq_mtime(t1, t2):
# dostime only stores even seconds, so be lenient
return abs(t1 - t2) <= 1
# Given the contents of a .py[co] file, unmarshal the data
# and return the code object. Return None if it the magic word doesn't
# match (we do this instead of raising an exception as we fall back
# to .py if available and we don't want to mask other errors).
def _unmarshal_code(pathname, data, mtime):
if len(data) < 16:
raise ZipImportError('bad pyc data')
raise ZipImportError("bad pyc data")
if data[:4] != _bootstrap_external.MAGIC_NUMBER:
_bootstrap._verbose_message('{!r} has bad magic', pathname)
_bootstrap._verbose_message("{!r} has bad magic", pathname)
return None # signal caller to try alternative
flags = _unpack_uint32(data[4:8])
@ -619,47 +618,57 @@ def _unmarshal_code(pathname, data, mtime):
# pycs. We could validate hash-based pycs against the source, but it
# seems likely that most people putting hash-based pycs in a zipfile
# will use unchecked ones.
if (_imp.check_hash_based_pycs != 'never' and
(flags != 0x1 or _imp.check_hash_based_pycs == 'always')):
if _imp.check_hash_based_pycs != "never" and (
flags != 0x1 or _imp.check_hash_based_pycs == "always"
):
return None
elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime):
_bootstrap._verbose_message('{!r} has bad mtime', pathname)
_bootstrap._verbose_message("{!r} has bad mtime", pathname)
return None # signal caller to try alternative
# XXX the pyc's size field is ignored; timestamp collisions are probably
# unimportant with zip files.
code = marshal.loads(data[16:])
if not isinstance(code, _code_type):
raise TypeError(f'compiled module {pathname!r} is not a code object')
raise TypeError(f"compiled module {pathname!r} is not a code object")
return code
_code_type = type(_unmarshal_code.__code__)
# Replace any occurrences of '\r\n?' in the input string with '\n'.
# This converts DOS and Mac line endings to Unix line endings.
def _normalize_line_endings(source):
source = source.replace(b'\r\n', b'\n')
source = source.replace(b'\r', b'\n')
source = source.replace(b"\r\n", b"\n")
source = source.replace(b"\r", b"\n")
return source
# Given a string buffer containing Python source code, compile it
# and return a code object.
def _compile_source(pathname, source):
source = _normalize_line_endings(source)
return compile(source, pathname, 'exec', dont_inherit=True)
return compile(source, pathname, "exec", dont_inherit=True)
# Convert the date/time values found in the Zip archive to a value
# that's compatible with the time stamp stored in .pyc files.
def _parse_dostime(d, t):
return time.mktime((
(d >> 9) + 1980, # bits 9..15: year
(d >> 5) & 0xF, # bits 5..8: month
d & 0x1F, # bits 0..4: day
t >> 11, # bits 11..15: hours
(t >> 5) & 0x3F, # bits 8..10: minutes
(t & 0x1F) * 2, # bits 0..7: seconds / 2
-1, -1, -1))
return time.mktime(
(
(d >> 9) + 1980, # bits 9..15: year
(d >> 5) & 0xF, # bits 5..8: month
d & 0x1F, # bits 0..4: day
t >> 11, # bits 11..15: hours
(t >> 5) & 0x3F, # bits 8..10: minutes
(t & 0x1F) * 2, # bits 0..7: seconds / 2
-1,
-1,
-1,
)
)
# Given a path to a .pyc file in the archive, return the
# modification time of the matching .py file, or 0 if no source
@ -667,7 +676,7 @@ def _parse_dostime(d, t):
def _get_mtime_of_source(self, path):
try:
# strip 'c' or 'o' from *.py[co]
assert path[-1:] in ('c', 'o')
assert path[-1:] in ("c", "o")
path = path[:-1]
toc_entry = self._files[path]
# fetch the time stamp of the .py file for comparison
@ -678,13 +687,14 @@ def _get_mtime_of_source(self, path):
except (KeyError, IndexError, TypeError):
return 0
# Get the code object associated with the module specified by
# 'fullname'.
def _get_module_code(self, fullname):
path = _get_module_path(self, fullname)
for suffix, isbytecode, ispackage in _zip_searchorder:
fullpath = path + suffix
_bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2)
_bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2)
try:
toc_entry = self._files[fullpath]
except KeyError:
@ -713,6 +723,7 @@ class _ZipImportResourceReader:
This class is allowed to reference all the innards and private parts of
the zipimporter.
"""
_registered = False
def __init__(self, zipimporter, fullname):
@ -720,9 +731,10 @@ class _ZipImportResourceReader:
self.fullname = fullname
def open_resource(self, resource):
fullname_as_path = self.fullname.replace('.', '/')
path = f'{fullname_as_path}/{resource}'
fullname_as_path = self.fullname.replace(".", "/")
path = f"{fullname_as_path}/{resource}"
from io import BytesIO
try:
return BytesIO(self.zipimporter.get_data(path))
except OSError:
@ -737,8 +749,8 @@ class _ZipImportResourceReader:
def is_resource(self, name):
# Maybe we could do better, but if we can get the data, it's a
# resource. Otherwise it isn't.
fullname_as_path = self.fullname.replace('.', '/')
path = f'{fullname_as_path}/{name}'
fullname_as_path = self.fullname.replace(".", "/")
path = f"{fullname_as_path}/{name}"
try:
self.zipimporter.get_data(path)
except OSError:
@ -754,11 +766,12 @@ class _ZipImportResourceReader:
# top of the archive, and then we iterate through _files looking for
# names inside that "directory".
from pathlib import Path
fullname_path = Path(self.zipimporter.get_filename(self.fullname))
relative_path = fullname_path.relative_to(self.zipimporter.archive)
# Don't forget that fullname names a package, so its path will include
# __init__.py, which we want to ignore.
assert relative_path.name == '__init__.py'
assert relative_path.name == "__init__.py"
package_path = relative_path.parent
subdirs_seen = set()
for filename in self.zipimporter._files:

View file

@ -1,2 +1,2 @@
from .abc import BasePluginLoader, PluginLoader, PluginClass, IDConflictError, PluginMeta
from .zip import ZippedPluginLoader, MaubotZipImportError
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
from .zip import MaubotZipImportError, ZippedPluginLoader

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,14 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING
from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
from abc import ABC, abstractmethod
import asyncio
from attr import dataclass
from packaging.version import Version, InvalidVersion
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from ..__meta__ import __version__
from ..plugin_base import Plugin
@ -89,16 +89,16 @@ class BasePluginLoader(ABC):
class PluginLoader(BasePluginLoader, ABC):
id_cache: Dict[str, 'PluginLoader'] = {}
id_cache: Dict[str, "PluginLoader"] = {}
meta: PluginMeta
references: Set['PluginInstance']
references: Set["PluginInstance"]
def __init__(self):
self.references = set()
@classmethod
def find(cls, plugin_id: str) -> 'PluginLoader':
def find(cls, plugin_id: str) -> "PluginLoader":
return cls.id_cache[plugin_id]
def to_dict(self) -> dict:
@ -109,12 +109,14 @@ class PluginLoader(BasePluginLoader, ABC):
}
async def stop_instances(self) -> None:
await asyncio.gather(*[instance.stop() for instance
in self.references if instance.started])
await asyncio.gather(
*[instance.stop() for instance in self.references if instance.started]
)
async def start_instances(self) -> None:
await asyncio.gather(*[instance.start() for instance
in self.references if instance.enabled])
await asyncio.gather(
*[instance.start() for instance in self.references if instance.enabled]
)
@abstractmethod
async def load(self) -> Type[PluginClass]:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,22 +13,23 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, List, Type, Tuple, Optional
from zipfile import ZipFile, BadZipFile
from time import time
import logging
import sys
import os
from __future__ import annotations
from time import time
from zipfile import BadZipFile, ZipFile
import logging
import os
import sys
from ruamel.yaml import YAML, YAMLError
from packaging.version import Version
from ruamel.yaml import YAML, YAMLError
from mautrix.types import SerializerError
from ..lib.zipimport import zipimporter, ZipImportError
from ..plugin_base import Plugin
from ..config import Config
from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError
from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin
from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
yaml = YAML()
@ -50,23 +51,25 @@ class MaubotZipLoadError(MaubotZipImportError):
class ZippedPluginLoader(PluginLoader):
path_cache: Dict[str, 'ZippedPluginLoader'] = {}
path_cache: dict[str, ZippedPluginLoader] = {}
log: logging.Logger = logging.getLogger("maubot.loader.zip")
trash_path: str = "delete"
directories: List[str] = []
directories: list[str] = []
path: str
meta: PluginMeta
main_class: str
main_module: str
_loaded: Type[PluginClass]
_importer: zipimporter
_file: ZipFile
path: str | None
meta: PluginMeta | None
main_class: str | None
main_module: str | None
_loaded: type[PluginClass] | None
_importer: zipimporter | None
_file: ZipFile | None
def __init__(self, path: str) -> None:
super().__init__()
self.path = path
self.meta = None
self.main_class = None
self.main_module = None
self._loaded = None
self._importer = None
self._file = None
@ -75,7 +78,8 @@ class ZippedPluginLoader(PluginLoader):
try:
existing = self.id_cache[self.meta.id]
raise IDConflictError(
f"Plugin with id {self.meta.id} already loaded from {existing.source}")
f"Plugin with id {self.meta.id} already loaded from {existing.source}"
)
except KeyError:
pass
self.path_cache[self.path] = self
@ -83,13 +87,10 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}")
def to_dict(self) -> dict:
return {
**super().to_dict(),
"path": self.path
}
return {**super().to_dict(), "path": self.path}
@classmethod
def get(cls, path: str) -> 'ZippedPluginLoader':
def get(cls, path: str) -> ZippedPluginLoader:
path = os.path.abspath(path)
try:
return cls.path_cache[path]
@ -101,10 +102,12 @@ class ZippedPluginLoader(PluginLoader):
return self.path
def __repr__(self) -> str:
return ("<ZippedPlugin "
f"path='{self.path}' "
f"meta={self.meta} "
f"loaded={self._loaded is not None}>")
return (
"<ZippedPlugin "
f"path='{self.path}' "
f"meta={self.meta} "
f"loaded={self._loaded is not None}>"
)
def sync_read_file(self, path: str) -> bytes:
return self._file.read(path)
@ -112,16 +115,19 @@ class ZippedPluginLoader(PluginLoader):
async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path)
def sync_list_files(self, directory: str) -> List[str]:
def sync_list_files(self, directory: str) -> list[str]:
directory = directory.rstrip("/")
return [file.filename for file in self._file.filelist
if os.path.dirname(file.filename) == directory]
return [
file.filename
for file in self._file.filelist
if os.path.dirname(file.filename) == directory
]
async def list_files(self, directory: str) -> List[str]:
async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory)
@staticmethod
def _read_meta(source) -> Tuple[ZipFile, PluginMeta]:
def _read_meta(source) -> tuple[ZipFile, PluginMeta]:
try:
file = ZipFile(source)
data = file.read("maubot.yaml")
@ -142,7 +148,7 @@ class ZippedPluginLoader(PluginLoader):
return file, meta
@classmethod
def verify_meta(cls, source) -> Tuple[str, Version]:
def verify_meta(cls, source) -> tuple[str, Version]:
_, meta = cls._read_meta(source)
return meta.id, meta.version
@ -173,24 +179,24 @@ class ZippedPluginLoader(PluginLoader):
code = importer.get_code(self.main_module.replace(".", "/"))
if self.main_class not in code.co_names:
raise MaubotZipPreLoadError(
f"Main class {self.main_class} not in {self.main_module}")
f"Main class {self.main_class} not in {self.main_module}"
)
except ZipImportError as e:
raise MaubotZipPreLoadError(
f"Main module {self.main_module} not found in file") from e
raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e
for module in self.meta.modules:
try:
importer.find_module(module)
except ZipImportError as e:
raise MaubotZipPreLoadError(f"Module {module} not found in file") from e
async def load(self, reset_cache: bool = False) -> Type[PluginClass]:
async def load(self, reset_cache: bool = False) -> type[PluginClass]:
try:
return self._load(reset_cache)
except MaubotZipImportError:
self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}")
raise
def _load(self, reset_cache: bool = False) -> Type[PluginClass]:
def _load(self, reset_cache: bool = False) -> type[PluginClass]:
if self._loaded is not None and not reset_cache:
return self._loaded
self._load_meta()
@ -219,7 +225,7 @@ class ZippedPluginLoader(PluginLoader):
self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}")
return plugin
async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]:
async def reload(self, new_path: str | None = None) -> type[PluginClass]:
await self.unload()
if new_path is not None:
self.path = new_path
@ -251,7 +257,7 @@ class ZippedPluginLoader(PluginLoader):
self.path = None
@classmethod
def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None:
def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None:
if cls.trash_path == "delete":
os.remove(file_path)
else:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from asyncio import AbstractEventLoop
import importlib
from aiohttp import web
from ...config import Config
from .base import routes, get_config, set_config, set_loop
from .auth import check_token
from .base import get_config, routes, set_config, set_loop
from .middleware import auth, error
@ -30,9 +31,11 @@ def features(request: web.Request) -> web.Response:
if err is None:
return web.json_response(data)
else:
return web.json_response({
"login": data["login"],
})
return web.json_response(
{
"login": data["login"],
}
)
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from __future__ import annotations
from time import time
from aiohttp import web
@ -21,7 +22,7 @@ from aiohttp import web
from mautrix.types import UserID
from mautrix.util.signed_token import sign_token, verify_token
from .base import routes, get_config
from .base import get_config, routes
from .responses import resp
@ -33,10 +34,13 @@ def is_valid_token(token: str) -> bool:
def create_token(user: UserID) -> str:
return sign_token(get_config()["server.unshared_secret"], {
"user_id": user,
"created_at": int(time()),
})
return sign_token(
get_config()["server.unshared_secret"],
{
"user_id": user,
"created_at": int(time()),
},
)
def get_token(request: web.Request) -> str:
@ -44,11 +48,11 @@ def get_token(request: web.Request) -> str:
if not token or not token.startswith("Bearer "):
token = request.query.get("access_token", None)
else:
token = token[len("Bearer "):]
token = token[len("Bearer ") :]
return token
def check_token(request: web.Request) -> Optional[web.Response]:
def check_token(request: web.Request) -> web.Response | None:
token = get_token(request)
if not token:
return resp.no_token

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web
from __future__ import annotations
import asyncio
from aiohttp import web
from ...__meta__ import __version__
from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef()
_config: Config = None
_loop: asyncio.AbstractEventLoop = None
_config: Config | None = None
_loop: asyncio.AbstractEventLoop | None = None
def set_config(config: Config) -> None:
@ -44,6 +47,4 @@ def get_loop() -> asyncio.AbstractEventLoop:
@routes.get("/version")
async def version(_: web.Request) -> web.Response:
return web.json_response({
"version": __version__
})
return web.json_response({"version": __version__})

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,17 +13,18 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from __future__ import annotations
from json import JSONDecodeError
from aiohttp import web
from mautrix.types import UserID, SyncToken, FilterID
from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken
from mautrix.client import Client as MatrixClient
from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError
from mautrix.types import FilterID, SyncToken, UserID
from ...db import DBClient
from ...client import Client
from ...db import DBClient
from .base import routes
from .responses import resp
@ -42,12 +43,17 @@ async def get_client(request: web.Request) -> web.Response:
return resp.found(client.to_dict())
async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
homeserver = data.get("homeserver", None)
access_token = data.get("access_token", None)
device_id = data.get("device_id", None)
new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token,
loop=Client.loop, client_session=Client.http_client)
new_client = MatrixClient(
mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
loop=Client.loop,
client_session=Client.http_client,
)
try:
whoami = await new_client.whoami()
except MatrixInvalidToken:
@ -64,13 +70,20 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
db_instance = DBClient(id=whoami.user_id, homeserver=homeserver, access_token=access_token,
enabled=data.get("enabled", True), next_batch=SyncToken(""),
filter_id=FilterID(""), sync=data.get("sync", True),
autojoin=data.get("autojoin", True), online=data.get("online", True),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id)
db_instance = DBClient(
id=whoami.user_id,
homeserver=homeserver,
access_token=access_token,
enabled=data.get("enabled", True),
next_batch=SyncToken(""),
filter_id=FilterID(""),
sync=data.get("sync", True),
autojoin=data.get("autojoin", True),
online=data.get("online", True),
displayname=data.get("displayname", "disable"),
avatar_url=data.get("avatar_url", "disable"),
device_id=device_id,
)
client = Client(db_instance)
client.db_instance.insert()
await client.start()
@ -79,9 +92,11 @@ async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try:
await client.update_access_details(data.get("access_token", None),
data.get("homeserver", None),
data.get("device_id", None))
await client.update_access_details(
data.get("access_token", None),
data.get("homeserver", None),
data.get("device_id", None),
)
except MatrixInvalidToken:
return resp.bad_client_access_token
except MatrixRequestError:
@ -91,9 +106,9 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
except ValueError as e:
str_err = str(e)
if str_err.startswith("MXID mismatch"):
return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):])
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: "):])
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
with client.db_instance.edit_mode():
await client.update_avatar_url(data.get("avatar_url", None))
await client.update_displayname(data.get("displayname", None))
@ -105,8 +120,9 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
client = Client.get(user_id, None)
if not client:
return await _create_client(user_id, data)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,26 +13,26 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Tuple, NamedTuple, Optional
from json import JSONDecodeError
from typing import Dict, NamedTuple, Optional, Tuple
from http import HTTPStatus
import hashlib
from json import JSONDecodeError
import asyncio
import hashlib
import hmac
import random
import string
import hmac
from aiohttp import web
from yarl import URL
from mautrix.api import SynapseAdminPath, Method, Path
from mautrix.errors import MatrixRequestError
from mautrix.api import Method, Path, SynapseAdminPath
from mautrix.client import ClientAPI
from mautrix.types import LoginType, LoginResponse
from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
from .base import routes, get_config, get_loop
from .base import get_config, get_loop, routes
from .client import _create_client, _create_or_update_client
from .responses import resp
from .client import _create_or_update_client, _create_client
def known_homeservers() -> Dict[str, Dict[str, str]]:
@ -59,8 +59,9 @@ class AuthRequestInfo(NamedTuple):
truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo],
Optional[web.Response]]:
async def read_client_auth_request(
request: web.Request,
) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None)
if not server:
@ -81,21 +82,30 @@ async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthR
base_url = server["url"]
except KeyError:
return None, resp.invalid_server
return AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()),
secret=server.get("secret"),
username=username,
password=password,
user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
), None
return (
AuthRequestInfo(
server_name=server_name,
client=ClientAPI(base_url=base_url, loop=get_loop()),
secret=server.get("secret"),
username=username,
password=password,
user_type=body.get("user_type", "bot"),
device_name=body.get("device_name", "Maubot"),
update_client=request.query.get("update_client", "").lower() in truthy_strings,
sso=sso,
),
None,
)
def generate_mac(secret: str, nonce: str, username: str, password: str, admin: bool = False,
user_type: str = None) -> str:
def generate_mac(
secret: str,
nonce: str,
username: str,
password: str,
admin: bool = False,
user_type: str = None,
) -> str:
mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1)
mac.update(nonce.encode("utf-8"))
mac.update(b"\x00")
@ -132,18 +142,24 @@ async def register(request: web.Request) -> web.Response:
try:
raw_res = await req.client.api.request(Method.POST, path, content=content)
except MatrixRequestError as e:
return web.json_response({
"errcode": e.errcode,
"error": e.message,
"http_status": e.http_status,
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.json_response(
{
"errcode": e.errcode,
"error": e.message,
"http_status": e.http_status,
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
login_res = LoginResponse.deserialize(raw_res)
if req.update_client:
return await _create_client(login_res.user_id, {
"homeserver": str(req.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
})
return await _create_client(
login_res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": login_res.access_token,
"device_id": login_res.device_id,
},
)
return web.json_response(login_res.serialize())
@ -162,13 +178,17 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
flows = await req.client.get_login_flows()
if not flows.supports_type(LoginType.SSO):
return resp.sso_not_supported
waiter_id = ''.join(random.choices(string.ascii_lowercase + string.digits, k=16))
waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16))
cfg = get_config()
public_url = (URL(cfg["server.public_url"]) / cfg["server.base_path"].lstrip("/")
/ "client/auth_external_sso/complete" / waiter_id)
sso_url = (req.client.api.base_url
.with_path(str(Path.login.sso.redirect))
.with_query({"redirectUrl": str(public_url)}))
public_url = (
URL(cfg["server.public_url"])
/ cfg["server.base_path"].lstrip("/")
/ "client/auth_external_sso/complete"
/ waiter_id
)
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
sso_waiters[waiter_id] = req, get_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
@ -178,25 +198,40 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
device_id = f"maubot_{device_id}"
try:
if req.sso:
res = await req.client.login(token=login_token, login_type=LoginType.TOKEN,
device_id=device_id, store_access_token=False,
initial_device_display_name=req.device_name)
res = await req.client.login(
token=login_token,
login_type=LoginType.TOKEN,
device_id=device_id,
store_access_token=False,
initial_device_display_name=req.device_name,
)
else:
res = await req.client.login(identifier=req.username, login_type=LoginType.PASSWORD,
password=req.password, device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False)
res = await req.client.login(
identifier=req.username,
login_type=LoginType.PASSWORD,
password=req.password,
device_id=device_id,
initial_device_display_name=req.device_name,
store_access_token=False,
)
except MatrixRequestError as e:
return web.json_response({
"errcode": e.errcode,
"error": e.message,
}, status=e.http_status)
return web.json_response(
{
"errcode": e.errcode,
"error": e.message,
},
status=e.http_status,
)
if req.update_client:
return await _create_or_update_client(res.user_id, {
"homeserver": str(req.client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
}, is_login=True)
return await _create_or_update_client(
res.user_id,
{
"homeserver": str(req.client.api.base_url),
"access_token": res.access_token,
"device_id": res.device_id,
},
is_login=True,
)
return web.json_response(res.serialize())
@ -230,6 +265,8 @@ async def complete_sso(request: web.Request) -> web.Response:
return web.Response(status=400, text="Missing loginToken query parameter\n")
except asyncio.InvalidStateError:
return web.Response(status=500, text="Invalid state\n")
return web.Response(status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n")
return web.Response(
status=200,
text="Login token received, please return to your Maubot client. "
"This tab can be closed.\n",
)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,7 +13,7 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from aiohttp import web, client as http
from aiohttp import client as http, web
from ...client import Client
from .base import routes
@ -45,8 +45,9 @@ async def proxy(request: web.Request) -> web.StreamResponse:
headers["X-Forwarded-For"] = f"{host}:{port}"
data = await request.read()
async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers,
params=query, data=data) as proxy_resp:
async with http.request(
request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data
) as proxy_resp:
response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers)
await response.prepare(request)
async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE):

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -14,11 +14,11 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from string import Template
from subprocess import run
import asyncio
import re
from ruamel.yaml import YAML
from aiohttp import web
from ruamel.yaml import YAML
from .base import routes
@ -27,9 +27,7 @@ enabled = False
@routes.get("/debug/open")
async def check_enabled(_: web.Request) -> web.Response:
return web.json_response({
"enabled": enabled,
})
return web.json_response({"enabled": enabled})
try:
@ -40,7 +38,6 @@ try:
editor_command = Template(cfg["editor"])
pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]]
@routes.post("/debug/open")
async def open_file(request: web.Request) -> web.Response:
data = await request.json()
@ -51,13 +48,9 @@ try:
cmd = editor_command.substitute(path=path, line=data["line"])
except (KeyError, ValueError):
return web.Response(status=400)
res = run(cmd, shell=True)
return web.json_response({
"return": res.returncode,
"stdout": res.stdout,
"stderr": res.stderr
})
res = await asyncio.create_subprocess_shell(cmd)
stdout, stderr = await res.communicate()
return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr})
enabled = True
except Exception:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,10 +17,10 @@ from json import JSONDecodeError
from aiohttp import web
from ...client import Client
from ...db import DBPlugin
from ...instance import PluginInstance
from ...loader import PluginLoader
from ...client import Client
from .base import routes
from .responses import resp
@ -52,8 +52,13 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
PluginLoader.find(plugin_type)
except KeyError:
return resp.plugin_type_not_found
db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True),
primary_user=primary_user, config=data.get("config", ""))
db_instance = DBPlugin(
id=instance_id,
type=plugin_type,
enabled=data.get("enabled", True),
primary_user=primary_user,
config=data.get("config", ""),
)
instance = PluginInstance(db_instance)
instance.load()
instance.db_instance.insert()

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,13 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, TYPE_CHECKING
from __future__ import annotations
from datetime import datetime
from aiohttp import web
from sqlalchemy import Table, Column, asc, desc, exc
from sqlalchemy.orm import Query
from sqlalchemy import Column, Table, asc, desc, exc
from sqlalchemy.engine.result import ResultProxy, RowProxy
from sqlalchemy.orm import Query
from ...instance import PluginInstance
from .base import routes
@ -34,23 +35,26 @@ async def get_database(request: web.Request) -> web.Response:
return resp.instance_not_found
elif not instance.inst_db:
return resp.plugin_has_no_database
if TYPE_CHECKING:
table: Table
column: Column
return web.json_response({
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
} for column in table.columns
},
} for table in instance.get_db_tables().values()
})
table: Table
column: Column
return web.json_response(
{
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
}
for column in table.columns
},
}
for table in instance.get_db_tables().values()
}
)
def check_type(val):
@ -74,9 +78,12 @@ async def get_table(request: web.Request) -> web.Response:
return resp.table_not_found
try:
order = [tuple(order.split(":")) for order in request.query.getall("order")]
order = [(asc if sort.lower() == "asc" else desc)(table.columns[column])
if sort else table.columns[column]
for column, sort in order]
order = [
(asc if sort.lower() == "asc" else desc)(table.columns[column])
if sort
else table.columns[column]
for column, sort in order
]
except KeyError:
order = []
limit = int(request.query.get("limit", 100))
@ -96,12 +103,12 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"]
except KeyError:
return resp.query_missing
return execute_query(instance, sql_query,
rows_as_dict=data.get("rows_as_dict", False))
return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False))
def execute_query(instance: PluginInstance, sql_query: Union[str, Query],
rows_as_dict: bool = False) -> web.Response:
def execute_query(
instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False
) -> web.Response:
try:
res: ResultProxy = instance.inst_db.execute(sql_query)
except exc.IntegrityError as e:
@ -114,10 +121,14 @@ def execute_query(instance: PluginInstance, sql_query: Union[str, Query],
}
if res.returns_rows:
row: RowProxy
data["rows"] = [({key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row])
for row in res]
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
data["columns"] = res.keys()
else:
data["rowcount"] = res.rowcount

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,31 +13,60 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Deque, List
from datetime import datetime
from __future__ import annotations
from collections import deque
import logging
from datetime import datetime
import asyncio
import logging
from aiohttp import web
from aiohttp import web, web_ws
from .base import routes, get_loop
from .auth import is_valid_token
from .base import get_loop, routes
BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName",
"levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name",
"pathname", "process", "processName", "relativeCreated", "stack_info", "thread",
"threadName"}
INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name",
"pathname"}
BUILTIN_ATTRS = {
"args",
"asctime",
"created",
"exc_info",
"exc_text",
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"msecs",
"message",
"msg",
"name",
"pathname",
"process",
"processName",
"relativeCreated",
"stack_info",
"thread",
"threadName",
}
INCLUDE_ATTRS = {
"filename",
"funcName",
"levelname",
"levelno",
"lineno",
"module",
"name",
"pathname",
}
EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS
MAX_LINES = 2048
class LogCollector(logging.Handler):
lines: Deque[dict]
lines: deque[dict]
formatter: logging.Formatter
listeners: List[web.WebSocketResponse]
listeners: list[web.WebSocketResponse]
loop: asyncio.AbstractEventLoop
def __init__(self, level=logging.NOTSET) -> None:
@ -56,9 +85,7 @@ class LogCollector(logging.Handler):
# JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license)
# https://github.com/marselester/json-log-formatter
content = {
name: value
for name, value in record.__dict__.items()
if name not in EXCLUDE_ATTRS
name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS
}
content["id"] = str(record.relativeCreated)
content["msg"] = record.getMessage()
@ -119,6 +146,7 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
asyncio.ensure_future(close_if_not_authenticated())
try:
msg: web_ws.WSMessage
async for msg in ws:
if msg.type != web.WSMsgType.TEXT:
continue

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,10 @@ import json
from aiohttp import web
from .base import routes, get_config
from .responses import resp
from .auth import create_token
from .base import get_config, routes
from .responses import resp
@routes.post("/auth/login")
async def login(request: web.Request) -> web.Response:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,15 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Callable, Awaitable
import logging
from typing import Awaitable, Callable
import base64
import logging
from aiohttp import web
from .responses import resp
from .auth import check_token
from .base import get_config
from .responses import resp
Handler = Callable[[web.Request], Awaitable[web.Response]]
log = logging.getLogger("maubot.server")
@ -29,7 +29,7 @@ log = logging.getLogger("maubot.server")
@web.middleware
async def auth(request: web.Request, handler: Handler) -> web.Response:
subpath = request.path[len(get_config()["server.base_path"]):]
subpath = request.path[len(get_config()["server.base_path"]) :]
if (
subpath.startswith("/auth/")
or subpath.startswith("/client/auth_external_sso/complete/")
@ -52,15 +52,18 @@ async def error(request: web.Request, handler: Handler) -> web.Response:
return resp.path_not_found
elif ex.status_code == 405:
return resp.method_not_allowed
return web.json_response({
"httpexception": {
"headers": {key: value for key, value in ex.headers.items()},
"class": type(ex).__name__,
"body": ex.text or base64.b64encode(ex.body)
return web.json_response(
{
"httpexception": {
"headers": {key: value for key, value in ex.headers.items()},
"class": type(ex).__name__,
"body": ex.text or base64.b64encode(ex.body),
},
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}",
"errcode": f"unhandled_http_{ex.status}",
},
"error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}",
"errcode": f"unhandled_http_{ex.status}",
}, status=ex.status)
status=ex.status,
)
except Exception:
log.exception("Error in handler")
return resp.internal_server_error

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -17,9 +17,9 @@ import traceback
from aiohttp import web
from ...loader import PluginLoader, MaubotZipImportError
from .responses import resp
from ...loader import MaubotZipImportError, PluginLoader
from .base import routes
from .responses import resp
@routes.get("/plugins")

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -15,16 +15,16 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from io import BytesIO
from time import time
import traceback
import os.path
import re
import traceback
from aiohttp import web
from packaging.version import Version
from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError
from ...loader import MaubotZipImportError, PluginLoader, ZippedPluginLoader
from .base import get_config, routes
from .responses import resp
from .base import routes, get_config
@routes.put("/plugin/{id}")
@ -78,15 +78,20 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R
return resp.created(plugin.to_dict())
async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes,
new_version: Version) -> web.Response:
async def upload_replacement_plugin(
plugin: ZippedPluginLoader, content: bytes, new_version: Version
) -> web.Response:
dirname = os.path.dirname(plugin.path)
old_filename = os.path.basename(plugin.path)
if str(plugin.meta.version) in old_filename:
replacement = (str(new_version) if plugin.meta.version != new_version
else f"{new_version}-ts{int(time())}")
filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?",
replacement, old_filename)
replacement = (
str(new_version)
if plugin.meta.version != new_version
else f"{new_version}-ts{int(time())}"
)
filename = re.sub(
f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename
)
else:
filename = old_filename.rstrip(".mbp")
filename = f"{filename}-v{new_version}.mbp"

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -16,296 +16,416 @@
from http import HTTPStatus
from aiohttp import web
from sqlalchemy.exc import OperationalError, IntegrityError
from sqlalchemy.exc import IntegrityError, OperationalError
class _Response:
@property
def body_not_json(self) -> web.Response:
return web.json_response({
"error": "Request body is not JSON",
"errcode": "body_not_json",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Request body is not JSON",
"errcode": "body_not_json",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def plugin_type_required(self) -> web.Response:
return web.json_response({
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Plugin type is required when creating plugin instances",
"errcode": "plugin_type_required",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def primary_user_required(self) -> web.Response:
return web.json_response({
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Primary user is required when creating plugin instances",
"errcode": "primary_user_required",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def bad_client_access_token(self) -> web.Response:
return web.json_response({
"error": "Invalid access token",
"errcode": "bad_client_access_token",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Invalid access token",
"errcode": "bad_client_access_token",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def bad_client_access_details(self) -> web.Response:
return web.json_response({
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details"
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Invalid homeserver or access token",
"errcode": "bad_client_access_details",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def bad_client_connection_details(self) -> web.Response:
return web.json_response({
"error": "Could not connect to homeserver",
"errcode": "bad_client_connection_details"
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Could not connect to homeserver",
"errcode": "bad_client_connection_details",
},
status=HTTPStatus.BAD_REQUEST,
)
def mxid_mismatch(self, found: str) -> web.Response:
return web.json_response({
"error": "The Matrix user ID of the client and the user ID of the access token don't "
f"match. Access token is for user {found}",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": (
"The Matrix user ID of the client and the user ID of the access token don't "
f"match. Access token is for user {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
def device_id_mismatch(self, found: str) -> web.Response:
return web.json_response({
"error": "The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}",
"errcode": "mxid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": (
"The Matrix device ID of the client and the device ID of the access token "
f"don't match. Access token is for device {found}"
),
"errcode": "mxid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def pid_mismatch(self) -> web.Response:
return web.json_response({
"error": "The ID in the path does not match the ID of the uploaded plugin",
"errcode": "pid_mismatch",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "The ID in the path does not match the ID of the uploaded plugin",
"errcode": "pid_mismatch",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def username_or_password_missing(self) -> web.Response:
return web.json_response({
"error": "Username or password missing",
"errcode": "username_or_password_missing",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Username or password missing",
"errcode": "username_or_password_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def query_missing(self) -> web.Response:
return web.json_response({
"error": "Query missing",
"errcode": "query_missing",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Query missing",
"errcode": "query_missing",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod
def sql_operational_error(error: OperationalError, query: str) -> web.Response:
return web.json_response({
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_operational_error",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_operational_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod
def sql_integrity_error(error: IntegrityError, query: str) -> web.Response:
return web.json_response({
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_integrity_error",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"ok": False,
"query": query,
"error": str(error.orig),
"full_error": str(error),
"errcode": "sql_integrity_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def bad_auth(self) -> web.Response:
return web.json_response({
"error": "Invalid username or password",
"errcode": "invalid_auth",
}, status=HTTPStatus.UNAUTHORIZED)
return web.json_response(
{
"error": "Invalid username or password",
"errcode": "invalid_auth",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property
def no_token(self) -> web.Response:
return web.json_response({
"error": "Authorization token missing",
"errcode": "auth_token_missing",
}, status=HTTPStatus.UNAUTHORIZED)
return web.json_response(
{
"error": "Authorization token missing",
"errcode": "auth_token_missing",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property
def invalid_token(self) -> web.Response:
return web.json_response({
"error": "Invalid authorization token",
"errcode": "auth_token_invalid",
}, status=HTTPStatus.UNAUTHORIZED)
return web.json_response(
{
"error": "Invalid authorization token",
"errcode": "auth_token_invalid",
},
status=HTTPStatus.UNAUTHORIZED,
)
@property
def plugin_not_found(self) -> web.Response:
return web.json_response({
"error": "Plugin not found",
"errcode": "plugin_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Plugin not found",
"errcode": "plugin_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def client_not_found(self) -> web.Response:
return web.json_response({
"error": "Client not found",
"errcode": "client_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Client not found",
"errcode": "client_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def primary_user_not_found(self) -> web.Response:
return web.json_response({
"error": "Client for given primary user not found",
"errcode": "primary_user_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Client for given primary user not found",
"errcode": "primary_user_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def instance_not_found(self) -> web.Response:
return web.json_response({
"error": "Plugin instance not found",
"errcode": "instance_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Plugin instance not found",
"errcode": "instance_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def plugin_type_not_found(self) -> web.Response:
return web.json_response({
"error": "Given plugin type not found",
"errcode": "plugin_type_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Given plugin type not found",
"errcode": "plugin_type_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def path_not_found(self) -> web.Response:
return web.json_response({
"error": "Resource not found",
"errcode": "resource_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Resource not found",
"errcode": "resource_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def server_not_found(self) -> web.Response:
return web.json_response({
"error": "Registration target server not found",
"errcode": "server_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Registration target server not found",
"errcode": "server_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_secret_not_found(self) -> web.Response:
return web.json_response({
"error": "Config does not have a registration secret for that server",
"errcode": "registration_secret_not_found",
}, status=HTTPStatus.NOT_FOUND)
return web.json_response(
{
"error": "Config does not have a registration secret for that server",
"errcode": "registration_secret_not_found",
},
status=HTTPStatus.NOT_FOUND,
)
@property
def registration_no_sso(self) -> web.Response:
return web.json_response({
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "The register operation is only for registering with a password",
"errcode": "registration_no_sso",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def sso_not_supported(self) -> web.Response:
return web.json_response({
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
}, status=HTTPStatus.FORBIDDEN)
return web.json_response(
{
"error": "That server does not seem to support single sign-on",
"errcode": "sso_not_supported",
},
status=HTTPStatus.FORBIDDEN,
)
@property
def plugin_has_no_database(self) -> web.Response:
return web.json_response({
"error": "Given plugin does not have a database",
"errcode": "plugin_has_no_database",
})
return web.json_response(
{
"error": "Given plugin does not have a database",
"errcode": "plugin_has_no_database",
}
)
@property
def table_not_found(self) -> web.Response:
return web.json_response({
"error": "Given table not found in plugin database",
"errcode": "table_not_found",
})
return web.json_response(
{
"error": "Given table not found in plugin database",
"errcode": "table_not_found",
}
)
@property
def method_not_allowed(self) -> web.Response:
return web.json_response({
"error": "Method not allowed",
"errcode": "method_not_allowed",
}, status=HTTPStatus.METHOD_NOT_ALLOWED)
return web.json_response(
{
"error": "Method not allowed",
"errcode": "method_not_allowed",
},
status=HTTPStatus.METHOD_NOT_ALLOWED,
)
@property
def user_exists(self) -> web.Response:
return web.json_response({
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
}, status=HTTPStatus.CONFLICT)
return web.json_response(
{
"error": "There is already a client with the user ID of that token",
"errcode": "user_exists",
},
status=HTTPStatus.CONFLICT,
)
@property
def plugin_exists(self) -> web.Response:
return web.json_response({
"error": "A plugin with the same ID as the uploaded plugin already exists",
"errcode": "plugin_exists"
}, status=HTTPStatus.CONFLICT)
return web.json_response(
{
"error": "A plugin with the same ID as the uploaded plugin already exists",
"errcode": "plugin_exists",
},
status=HTTPStatus.CONFLICT,
)
@property
def plugin_in_use(self) -> web.Response:
return web.json_response({
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
return web.json_response(
{
"error": "Plugin instances of this type still exist",
"errcode": "plugin_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@property
def client_in_use(self) -> web.Response:
return web.json_response({
"error": "Plugin instances with this client as their primary user still exist",
"errcode": "client_in_use",
}, status=HTTPStatus.PRECONDITION_FAILED)
return web.json_response(
{
"error": "Plugin instances with this client as their primary user still exist",
"errcode": "client_in_use",
},
status=HTTPStatus.PRECONDITION_FAILED,
)
@staticmethod
def plugin_import_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_invalid",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_invalid",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod
def plugin_reload_error(error: str, stacktrace: str) -> web.Response:
return web.json_response({
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_reload_fail",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.json_response(
{
"error": error,
"stacktrace": stacktrace,
"errcode": "plugin_reload_fail",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property
def internal_server_error(self) -> web.Response:
return web.json_response({
"error": "Internal server error",
"errcode": "internal_server_error",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.json_response(
{
"error": "Internal server error",
"errcode": "internal_server_error",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property
def invalid_server(self) -> web.Response:
return web.json_response({
"error": "Invalid registration server object in maubot configuration",
"errcode": "invalid_server",
}, status=HTTPStatus.INTERNAL_SERVER_ERROR)
return web.json_response(
{
"error": "Invalid registration server object in maubot configuration",
"errcode": "invalid_server",
},
status=HTTPStatus.INTERNAL_SERVER_ERROR,
)
@property
def unsupported_plugin_loader(self) -> web.Response:
return web.json_response({
"error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader",
}, status=HTTPStatus.BAD_REQUEST)
return web.json_response(
{
"error": "Existing plugin with same ID uses unsupported plugin loader",
"errcode": "unsupported_plugin_loader",
},
status=HTTPStatus.BAD_REQUEST,
)
@property
def not_implemented(self) -> web.Response:
return web.json_response({
"error": "Not implemented",
"errcode": "not_implemented",
}, status=HTTPStatus.NOT_IMPLEMENTED)
return web.json_response(
{
"error": "Not implemented",
"errcode": "not_implemented",
},
status=HTTPStatus.NOT_IMPLEMENTED,
)
@property
def ok(self) -> web.Response:
return web.json_response({
"success": True,
}, status=HTTPStatus.OK)
return web.json_response(
{"success": True},
status=HTTPStatus.OK,
)
@property
def deleted(self) -> web.Response:
@ -320,15 +440,10 @@ class _Response:
return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK)
def logged_in(self, token: str) -> web.Response:
return self.found({
"token": token,
})
return self.found({"token": token})
def pong(self, user: str, features: dict) -> web.Response:
return self.found({
"username": user,
"features": features,
})
return self.found({"username": user, "features": features})
@staticmethod
def created(data: dict) -> web.Response:

View file

@ -1,6 +1,6 @@
<!--
maubot - A plugin-based Matrix bot system.
Copyright (C) 2019 Tulir Asokan
Copyright (C) 2022 Tulir Asokan
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2021 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
// maubot - A plugin-based Matrix bot system.
// Copyright (C) 2019 Tulir Asokan
// Copyright (C) 2022 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,23 +13,36 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Union, Awaitable, Optional, Tuple, List
from __future__ import annotations
from typing import Awaitable
from html import escape
import asyncio
import attr
from mautrix.client import Client as MatrixClient, SyncStream
from mautrix.util.formatter import MatrixParser, MarkdownString, EntityType
from mautrix.util import markdown
from mautrix.types import (EventType, MessageEvent, Event, EventID, RoomID, MessageEventContent,
MessageType, TextMessageEventContent, Format, RelatesTo, EncryptedEvent)
from mautrix.errors import DecryptionError
from mautrix.types import (
EncryptedEvent,
Event,
EventID,
EventType,
Format,
MessageEvent,
MessageEventContent,
MessageType,
RelatesTo,
RoomID,
TextMessageEventContent,
)
from mautrix.util import markdown
from mautrix.util.formatter import EntityType, MarkdownString, MatrixParser
class HumanReadableString(MarkdownString):
def format(self, entity_type: EntityType, **kwargs) -> 'MarkdownString':
if entity_type == EntityType.URL and kwargs['url'] != self.text:
def format(self, entity_type: EntityType, **kwargs) -> MarkdownString:
if entity_type == EntityType.URL and kwargs["url"] != self.text:
self.text = f"{self.text} ({kwargs['url']})"
return self
return super(HumanReadableString, self).format(entity_type, **kwargs)
@ -39,8 +52,9 @@ class MaubotHTMLParser(MatrixParser[HumanReadableString]):
fs = HumanReadableString
async def parse_formatted(message: str, allow_html: bool = False, render_markdown: bool = True
) -> Tuple[str, str]:
async def parse_formatted(
message: str, allow_html: bool = False, render_markdown: bool = True
) -> tuple[str, str]:
if render_markdown:
html = markdown.render(message, allow_html=allow_html)
elif allow_html:
@ -51,19 +65,25 @@ async def parse_formatted(message: str, allow_html: bool = False, render_markdow
class MaubotMessageEvent(MessageEvent):
client: 'MaubotMatrixClient'
client: MaubotMatrixClient
disable_reply: bool
def __init__(self, base: MessageEvent, client: 'MaubotMatrixClient'):
super().__init__(**{a.name.lstrip("_"): getattr(base, a.name)
for a in attr.fields(MessageEvent)})
def __init__(self, base: MessageEvent, client: MaubotMatrixClient):
super().__init__(
**{a.name.lstrip("_"): getattr(base, a.name) for a in attr.fields(MessageEvent)}
)
self.client = client
self.disable_reply = client.disable_replies
async def respond(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True,
allow_html: bool = False, reply: Union[bool, str] = False,
edits: Optional[Union[EventID, MessageEvent]] = None) -> EventID:
async def respond(
self,
content: str | MessageEventContent,
event_type: EventType = EventType.ROOM_MESSAGE,
markdown: bool = True,
allow_html: bool = False,
reply: bool | str = False,
edits: EventID | MessageEvent | None = None,
) -> EventID:
if isinstance(content, str):
content = TextMessageEventContent(msgtype=MessageType.NOTICE, body=content)
if allow_html or markdown:
@ -77,18 +97,25 @@ class MaubotMessageEvent(MessageEvent):
if reply != "force" and self.disable_reply:
content.body = f"{self.sender}: {content.body}"
fmt_body = content.formatted_body or escape(content.body).replace("\n", "<br>")
content.formatted_body = (f'<a href="https://matrix.to/#/{self.sender}">'
f'{self.sender}'
f'</a>: {fmt_body}')
content.formatted_body = (
f'<a href="https://matrix.to/#/{self.sender}">'
f"{self.sender}"
f"</a>: {fmt_body}"
)
else:
content.set_reply(self)
return await self.client.send_message_event(self.room_id, event_type, content)
def reply(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True,
allow_html: bool = False) -> Awaitable[EventID]:
return self.respond(content, event_type, markdown=markdown, reply=True,
allow_html=allow_html)
def reply(
self,
content: str | MessageEventContent,
event_type: EventType = EventType.ROOM_MESSAGE,
markdown: bool = True,
allow_html: bool = False,
) -> Awaitable[EventID]:
return self.respond(
content, event_type, markdown=markdown, reply=True, allow_html=allow_html
)
def mark_read(self) -> Awaitable[None]:
return self.client.send_receipt(self.room_id, self.event_id, "m.read")
@ -96,11 +123,16 @@ class MaubotMessageEvent(MessageEvent):
def react(self, key: str) -> Awaitable[EventID]:
return self.client.react(self.room_id, self.event_id, key)
def edit(self, content: Union[str, MessageEventContent],
event_type: EventType = EventType.ROOM_MESSAGE, markdown: bool = True,
allow_html: bool = False) -> Awaitable[EventID]:
return self.respond(content, event_type, markdown=markdown, edits=self,
allow_html=allow_html)
def edit(
self,
content: str | MessageEventContent,
event_type: EventType = EventType.ROOM_MESSAGE,
markdown: bool = True,
allow_html: bool = False,
) -> Awaitable[EventID]:
return self.respond(
content, event_type, markdown=markdown, edits=self, allow_html=allow_html
)
class MaubotMatrixClient(MatrixClient):
@ -110,11 +142,17 @@ class MaubotMatrixClient(MatrixClient):
super().__init__(*args, **kwargs)
self.disable_replies = False
async def send_markdown(self, room_id: RoomID, markdown: str, *, allow_html: bool = False,
msgtype: MessageType = MessageType.TEXT,
edits: Optional[Union[EventID, MessageEvent]] = None,
relates_to: Optional[RelatesTo] = None, **kwargs
) -> EventID:
async def send_markdown(
self,
room_id: RoomID,
markdown: str,
*,
allow_html: bool = False,
msgtype: MessageType = MessageType.TEXT,
edits: EventID | MessageEvent | None = None,
relates_to: RelatesTo | None = None,
**kwargs,
) -> EventID:
content = TextMessageEventContent(msgtype=msgtype, format=Format.HTML)
content.body, content.formatted_body = await parse_formatted(
markdown, allow_html=allow_html
@ -127,7 +165,7 @@ class MaubotMatrixClient(MatrixClient):
content.set_edit(edits)
return await self.send_message(room_id, content, **kwargs)
def dispatch_event(self, event: Event, source: SyncStream) -> List[asyncio.Task]:
def dispatch_event(self, event: Event, source: SyncStream) -> list[asyncio.Task]:
if isinstance(event, MessageEvent) and not isinstance(event, MaubotMessageEvent):
event = MaubotMessageEvent(event, self)
elif source != SyncStream.INTERNAL:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,38 +13,50 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Optional, TYPE_CHECKING
from __future__ import annotations
from typing import TYPE_CHECKING
from abc import ABC
from asyncio import AbstractEventLoop
from sqlalchemy.engine.base import Engine
from aiohttp import ClientSession
from sqlalchemy.engine.base import Engine
from yarl import URL
if TYPE_CHECKING:
from mautrix.util.logging import TraceLogger
from mautrix.util.config import BaseProxyConfig
from mautrix.util.logging import TraceLogger
from .client import MaubotMatrixClient
from .plugin_server import PluginWebApp
from .loader import BasePluginLoader
from .plugin_server import PluginWebApp
class Plugin(ABC):
client: 'MaubotMatrixClient'
client: MaubotMatrixClient
http: ClientSession
id: str
log: 'TraceLogger'
log: TraceLogger
loop: AbstractEventLoop
loader: 'BasePluginLoader'
config: Optional['BaseProxyConfig']
database: Optional[Engine]
webapp: Optional['PluginWebApp']
webapp_url: Optional[URL]
loader: BasePluginLoader
config: BaseProxyConfig | None
database: Engine | None
webapp: PluginWebApp | None
webapp_url: URL | None
def __init__(self, client: 'MaubotMatrixClient', loop: AbstractEventLoop, http: ClientSession,
instance_id: str, log: 'TraceLogger', config: Optional['BaseProxyConfig'],
database: Optional[Engine], webapp: Optional['PluginWebApp'],
webapp_url: Optional[str], loader: 'BasePluginLoader') -> None:
def __init__(
self,
client: MaubotMatrixClient,
loop: AbstractEventLoop,
http: ClientSession,
instance_id: str,
log: TraceLogger,
config: BaseProxyConfig | None,
database: Engine | None,
webapp: PluginWebApp | None,
webapp_url: str | None,
loader: BasePluginLoader,
) -> None:
self.client = client
self.loop = loop
self.http = http
@ -74,8 +86,10 @@ class Plugin(ABC):
else:
if len(web_handlers) > 0 and self.webapp is None:
if not warned_webapp:
self.log.warning(f"{type(obj).__name__} has web handlers, but the webapp"
" feature isn't enabled in the plugin's maubot.yaml")
self.log.warning(
f"{type(obj).__name__} has web handlers, but the webapp"
" feature isn't enabled in the plugin's maubot.yaml"
)
warned_webapp = True
continue
for method, path, kwargs in web_handlers:
@ -107,7 +121,7 @@ class Plugin(ABC):
pass
@classmethod
def get_config_class(cls) -> Optional[Type['BaseProxyConfig']]:
def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None
def on_external_config_update(self) -> None:

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,10 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List, Callable, Awaitable
from __future__ import annotations
from typing import Awaitable, Callable
from functools import partial
from aiohttp import web, hdrs
from aiohttp import hdrs, web
from yarl import URL
Handler = Callable[[web.Request], Awaitable[web.Response]]
@ -26,7 +28,7 @@ Middleware = Callable[[web.Request, Handler], Awaitable[web.Response]]
class PluginWebApp(web.UrlDispatcher):
def __init__(self):
super().__init__()
self._middleware: List[Middleware] = []
self._middleware: list[Middleware] = []
def add_middleware(self, middleware: Middleware) -> None:
self._middleware.append(middleware)
@ -58,8 +60,8 @@ class PluginWebApp(web.UrlDispatcher):
class PrefixResource(web.Resource):
def __init__(self, prefix, *, name=None):
assert not prefix or prefix.startswith('/'), prefix
assert prefix in ('', '/') or not prefix.endswith('/'), prefix
assert not prefix or prefix.startswith("/"), prefix
assert prefix in ("", "/") or not prefix.endswith("/"), prefix
super().__init__(name=name)
self._prefix = URL.build(path=prefix).raw_path
@ -68,14 +70,14 @@ class PrefixResource(web.Resource):
return self._prefix
def get_info(self):
return {'path': self._prefix}
return {"path": self._prefix}
def url_for(self):
return URL.build(path=self._prefix, encoded=True)
def add_prefix(self, prefix):
assert prefix.startswith('/')
assert not prefix.endswith('/')
assert prefix.startswith("/")
assert not prefix.endswith("/")
assert len(prefix) > 1
self._prefix = prefix + self._prefix
@ -84,4 +86,3 @@ class PrefixResource(web.Resource):
def raw_match(self, path: str) -> bool:
return path and path.startswith(self._prefix)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2019 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,36 +13,40 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, Dict
import logging
from __future__ import annotations
import asyncio
import json
from yarl import URL
import logging
from aiohttp import web, hdrs
from aiohttp import hdrs, web
from aiohttp.abc import AbstractAccessLogger
from yarl import URL
import pkg_resources
from mautrix.api import PathBuilder, Method
from mautrix.api import Method, PathBuilder
from .config import Config
from .plugin_server import PrefixResource, PluginWebApp
from .__meta__ import __version__
from .config import Config
from .plugin_server import PluginWebApp, PrefixResource
class AccessLogger(AbstractAccessLogger):
def log(self, request: web.Request, response: web.Response, time: int):
self.logger.info(f'{request.remote} "{request.method} {request.path} '
f'{response.status} {response.body_length} '
f'in {round(time, 4)}s"')
self.logger.info(
f'{request.remote} "{request.method} {request.path} '
f"{response.status} {response.body_length} "
f'in {round(time, 4)}s"'
)
class MaubotServer:
log: logging.Logger = logging.getLogger("maubot.server")
plugin_routes: Dict[str, PluginWebApp]
plugin_routes: dict[str, PluginWebApp]
def __init__(self, management_api: web.Application, config: Config,
loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self, management_api: web.Application, config: Config, loop: asyncio.AbstractEventLoop
) -> None:
self.loop = loop or asyncio.get_event_loop()
self.app = web.Application(loop=self.loop, client_max_size=100 * 1024 * 1024)
self.config = config
@ -57,13 +61,15 @@ class MaubotServer:
async def handle_plugin_path(self, request: web.Request) -> web.StreamResponse:
for path, app in self.plugin_routes.items():
if request.path.startswith(path):
request = request.clone(rel_url=request.rel_url
.with_path(request.rel_url.path[len(path):])
.with_query(request.query_string))
request = request.clone(
rel_url=request.rel_url.with_path(
request.rel_url.path[len(path) :]
).with_query(request.query_string)
)
return await app.handle(request)
return web.Response(status=404)
def get_instance_subapp(self, instance_id: str) -> Tuple[PluginWebApp, str]:
def get_instance_subapp(self, instance_id: str) -> tuple[PluginWebApp, str]:
subpath = self.config["server.plugin_base_path"] + instance_id
url = self.config["server.public_url"] + subpath
try:
@ -94,8 +100,9 @@ class MaubotServer:
ui_base = self.config["server.ui_base_path"]
if ui_base == "/":
ui_base = ""
directory = (self.config["server.override_resource_path"]
or pkg_resources.resource_filename("maubot", "management/frontend/build"))
directory = self.config[
"server.override_resource_path"
] or pkg_resources.resource_filename("maubot", "management/frontend/build")
self.app.router.add_static(f"{ui_base}/static", f"{directory}/static")
self.setup_static_root_files(directory, ui_base)
@ -115,8 +122,9 @@ class MaubotServer:
raise web.HTTPFound(f"{ui_base}/")
self.app.middlewares.append(frontend_404_middleware)
self.app.router.add_get(f"{ui_base}/", lambda _: web.Response(body=index_html,
content_type="text/html"))
self.app.router.add_get(
f"{ui_base}/", lambda _: web.Response(body=index_html, content_type="text/html")
)
self.app.router.add_get(ui_base, ui_base_redirect)
def setup_static_root_files(self, directory: str, ui_base: str) -> None:
@ -128,8 +136,9 @@ class MaubotServer:
for file, mime in files.items():
with open(f"{directory}/{file}", "rb") as stream:
data = stream.read()
self.app.router.add_get(f"{ui_base}/{file}", lambda _: web.Response(body=data,
content_type=mime))
self.app.router.add_get(
f"{ui_base}/{file}", lambda _: web.Response(body=data, content_type=mime)
)
# also set up a resource path for the public url path prefix config
# cut the prefix path from public_url
@ -143,8 +152,12 @@ class MaubotServer:
api_path = f"{public_url_path}{base_path}"
path_prefix_response_body = json.dumps({"api_path": api_path.rstrip("/")})
self.app.router.add_get(f"{ui_base}/paths.json", lambda _: web.Response(body=path_prefix_response_body,
content_type="application/json"))
self.app.router.add_get(
f"{ui_base}/paths.json",
lambda _: web.Response(
body=path_prefix_response_body, content_type="application/json"
),
)
def add_route(self, method: Method, path: PathBuilder, handler) -> None:
self.app.router.add_route(method.value, str(path), handler)
@ -161,9 +174,7 @@ class MaubotServer:
@staticmethod
async def version(_: web.Request) -> web.Response:
return web.json_response({
"version": __version__
})
return web.json_response({"version": __version__})
async def handle_transaction(self, request: web.Request) -> web.Response:
return web.Response(status=501)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,43 +13,53 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional, Type, cast
import logging.config
import importlib
from __future__ import annotations
from typing import cast
import argparse
import asyncio
import copy
import importlib
import logging.config
import os.path
import signal
import copy
import sys
from aiohttp import ClientSession, hdrs, web
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from aiohttp import web, hdrs, ClientSession
from yarl import URL
import sqlalchemy as sql
from mautrix.util.config import RecursiveDict, BaseMissingError
from mautrix.types import (
EventType,
Filter,
FilterID,
Membership,
RoomEventFilter,
RoomFilter,
StrippedStateEvent,
SyncToken,
)
from mautrix.util.config import BaseMissingError, RecursiveDict
from mautrix.util.db import Base
from mautrix.util.logging import TraceLogger
from mautrix.types import (Filter, RoomFilter, RoomEventFilter, StrippedStateEvent,
EventType, Membership, FilterID, SyncToken)
from ..__meta__ import __version__
from ..lib.store_proxy import SyncStoreProxy
from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient
from ..plugin_base import Plugin
from ..plugin_server import PluginWebApp, PrefixResource
from ..loader import PluginMeta
from ..server import AccessLogger
from ..matrix import MaubotMatrixClient
from ..lib.store_proxy import SyncStoreProxy
from ..__meta__ import __version__
from .config import Config
from .loader import FileSystemLoader
from .database import NextBatch
from .loader import FileSystemLoader
crypto_import_error = None
try:
from mautrix.crypto import OlmMachine, PgCryptoStore, PgCryptoStateStore
from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
from mautrix.util.async_db import Database as AsyncDatabase
except ImportError as err:
crypto_import_error = err
@ -57,15 +67,32 @@ except ImportError as err:
parser = argparse.ArgumentParser(
description="A plugin-based Matrix bot system -- standalone mode.",
prog="python -m maubot.standalone")
parser.add_argument("-c", "--config", type=str, default="config.yaml",
metavar="<path>", help="the path to your config file")
parser.add_argument("-b", "--base-config", type=str,
default="pkg://maubot.standalone/example-config.yaml",
metavar="<path>", help="the path to the example config "
"(for automatic config updates)")
parser.add_argument("-m", "--meta", type=str, default="maubot.yaml",
metavar="<path>", help="the path to your plugin metadata file")
prog="python -m maubot.standalone",
)
parser.add_argument(
"-c",
"--config",
type=str,
default="config.yaml",
metavar="<path>",
help="the path to your config file",
)
parser.add_argument(
"-b",
"--base-config",
type=str,
default="pkg://maubot.standalone/example-config.yaml",
metavar="<path>",
help="the path to the example config " "(for automatic config updates)",
)
parser.add_argument(
"-m",
"--meta",
type=str,
default="maubot.yaml",
metavar="<path>",
help="the path to your plugin metadata file",
)
args = parser.parse_args()
config = Config(args.config, args.base_config)
@ -92,7 +119,7 @@ else:
module = meta.modules[0]
main_class = meta.main_class
bot_module = importlib.import_module(module)
plugin: Type[Plugin] = getattr(bot_module, main_class)
plugin: type[Plugin] = getattr(bot_module, main_class)
loader = FileSystemLoader(os.path.dirname(args.meta))
log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
@ -110,8 +137,10 @@ access_token = config["user.credentials.access_token"]
crypto_store = crypto_db = state_store = None
if device_id and not OlmMachine:
log.warning("device_id set in config, but encryption dependencies not installed",
exc_info=crypto_import_error)
log.warning(
"device_id set in config, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
elif device_id:
crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table)
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db)
@ -124,27 +153,25 @@ if not nb:
bot_config = None
if not meta.config and "base-config.yaml" in meta.extra_files:
log.warning("base-config.yaml in extra files, but config is not set to true. "
"Assuming legacy plugin and loading config.")
log.warning(
"base-config.yaml in extra files, but config is not set to true. "
"Assuming legacy plugin and loading config."
)
meta.config = True
if meta.config:
log.debug("Loading config")
config_class = plugin.get_config_class()
def load() -> CommentedMap:
return config["plugin_config"]
def load_base() -> RecursiveDict[CommentedMap]:
return RecursiveDict(config.load_base()["plugin_config"], CommentedMap)
def save(data: RecursiveDict[CommentedMap]) -> None:
config["plugin_config"] = data
config.save()
try:
bot_config = config_class(load=load, load_base=load_base, save=save)
bot_config.load_and_update()
@ -161,9 +188,11 @@ if meta.webapp:
async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
if req.path.startswith(web_base_path):
req = req.clone(rel_url=req.rel_url
.with_path(req.rel_url.path[len(web_base_path):])
.with_query(req.query_string))
req = req.clone(
rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query(
req.query_string
)
)
return await plugin_webapp.handle(req)
return web.Response(status=404)
@ -175,8 +204,8 @@ else:
loop = asyncio.get_event_loop()
client: Optional[MaubotMatrixClient] = None
bot: Optional[Plugin] = None
client: MaubotMatrixClient | None = None
bot: Plugin | None = None
async def main():
@ -185,10 +214,17 @@ async def main():
global client, bot
client_log = logging.getLogger("maubot.client").getChild(user_id)
client = MaubotMatrixClient(mxid=user_id, base_url=homeserver, token=access_token,
client_session=http_client, loop=loop, log=client_log,
sync_store=SyncStoreProxy(nb), state_store=state_store,
device_id=device_id)
client = MaubotMatrixClient(
mxid=user_id,
base_url=homeserver,
token=access_token,
client_session=http_client,
loop=loop,
log=client_log,
sync_store=SyncStoreProxy(nb),
state_store=state_store,
device_id=device_id,
)
client.ignore_first_sync = config["user.ignore_first_sync"]
client.ignore_initial_sync = config["user.ignore_initial_sync"]
if crypto_store:
@ -199,8 +235,10 @@ async def main():
client.crypto = OlmMachine(client, crypto_store, state_store)
crypto_device_id = await crypto_store.get_device_id()
if crypto_device_id and crypto_device_id != device_id:
log.fatal("Mismatching device ID in crypto store and config "
f"(store: {crypto_device_id}, config: {device_id})")
log.fatal(
"Mismatching device ID in crypto store and config "
f"(store: {crypto_device_id}, config: {device_id})"
)
sys.exit(10)
await client.crypto.load()
if not crypto_device_id:
@ -224,17 +262,23 @@ async def main():
log.fatal(f"User ID mismatch: configured {user_id}, but server said {whoami.user_id}")
sys.exit(11)
elif whoami.device_id and device_id and whoami.device_id != device_id:
log.fatal(f"Device ID mismatch: configured {device_id}, "
f"but server said {whoami.device_id}")
log.fatal(
f"Device ID mismatch: configured {device_id}, "
f"but server said {whoami.device_id}"
)
sys.exit(12)
log.debug(f"Confirmed connection as {whoami.user_id} / {whoami.device_id}")
break
if config["user.sync"]:
if not nb.filter_id:
nb.edit(filter_id=await client.create_filter(Filter(
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
)))
nb.edit(
filter_id=await client.create_filter(
Filter(
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
)
)
)
client.start(nb.filter_id)
if config["user.autojoin"]:
@ -252,9 +296,18 @@ async def main():
await client.set_displayname(displayname)
plugin_log = cast(TraceLogger, logging.getLogger("maubot.instance.__main__"))
bot = plugin(client=client, loop=loop, http=http_client, instance_id="__main__",
log=plugin_log, config=bot_config, database=db if meta.database else None,
webapp=plugin_webapp, webapp_url=public_url, loader=loader)
bot = plugin(
client=client,
loop=loop,
http=http_client,
instance_id="__main__",
log=plugin_log,
config=bot_config,
database=db if meta.database else None,
webapp=plugin_webapp,
webapp_url=public_url,
loader=loader,
)
await bot.internal_start()

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Optional
from __future__ import annotations
import sqlalchemy as sql
from mautrix.types import FilterID, SyncToken, UserID
from mautrix.util.db import Base
from mautrix.types import UserID, SyncToken, FilterID
class NextBatch(Base):
@ -29,5 +29,5 @@ class NextBatch(Base):
filter_id: FilterID = sql.Column(sql.String(255))
@classmethod
def get(cls, user_id: UserID) -> Optional['NextBatch']:
def get(cls, user_id: UserID) -> NextBatch | None:
return cls._select_one_or_none(cls.c.user_id == user_id)

View file

@ -1,5 +1,5 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2021 Tulir Asokan
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
@ -13,12 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import List
import os.path
from __future__ import annotations
import os
import os.path
from ..loader import BasePluginLoader
class FileSystemLoader(BasePluginLoader):
def __init__(self, path: str) -> None:
self.path = path
@ -34,8 +36,8 @@ class FileSystemLoader(BasePluginLoader):
async def read_file(self, path: str) -> bytes:
return self.sync_read_file(path)
def sync_list_files(self, directory: str) -> List[str]:
def sync_list_files(self, directory: str) -> list[str]:
return os.listdir(os.path.join(self.path, directory))
async def list_files(self, directory: str) -> List[str]:
async def list_files(self, directory: str) -> list[str]:
return self.sync_list_files(directory)

14
pyproject.toml Normal file
View file

@ -0,0 +1,14 @@
[tool.isort]
profile = "black"
force_to_top = "typing"
from_first = true
combine_as_imports = true
known_first_party = "mautrix"
line_length = 99
skip = ["maubot/management/frontend"]
[tool.black]
line-length = 99
target-version = ["py38"]
required-version = "22.1.0"
force-exclude = "maubot/management/frontend"

View file

@ -1,4 +1,4 @@
mautrix==0.15.0rc4
mautrix>=0.15.0,<0.16
aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4