diff --git a/.editorconfig b/.editorconfig index 3d6370d..f1ad5ca 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,6 +14,3 @@ indent_size = 2 [spec.yaml] indent_size = 2 - -[CHANGELOG.md] -max_line_length = 80 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml new file mode 100644 index 0000000..44f4c91 --- /dev/null +++ b/.github/FUNDING.yml @@ -0,0 +1 @@ +github: tulir diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml deleted file mode 100644 index 28d6df2..0000000 --- a/.github/workflows/python-lint.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Python lint - -on: [push, pull_request] - -jobs: - lint: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.13" - - uses: isort/isort-action@master - with: - sortPaths: "./maubot" - - uses: psf/black@stable - with: - src: "./maubot" - version: "24.10.0" - - name: pre-commit - run: | - pip install pre-commit - pre-commit run -av trailing-whitespace - pre-commit run -av end-of-file-fixer - pre-commit run -av check-yaml - pre-commit run -av check-added-large-files diff --git a/.gitignore b/.gitignore index 9fd28ef..d475bc1 100644 --- a/.gitignore +++ b/.gitignore @@ -7,13 +7,10 @@ pip-selfcheck.json *.pyc __pycache__ -*.db* -*.log +*.db /*.yaml !example-config.yaml -!.pre-commit-config.yaml -/start logs/ plugins/ trash/ diff --git a/.gitlab-ci-plugin.yml b/.gitlab-ci-plugin.yml deleted file mode 100644 index 45ef06b..0000000 --- a/.gitlab-ci-plugin.yml +++ /dev/null @@ -1,29 +0,0 @@ -image: dock.mau.dev/maubot/maubot - -stages: -- build - -variables: - PYTHONPATH: /opt/maubot - -build: - stage: build - except: - - tags - script: - - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_REF_NAME-$CI_COMMIT_SHORT_SHA.mbp - artifacts: - paths: - - "*.mbp" - expire_in: 365 days - -build tags: - stage: build - only: - - tags - script: - - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_TAG.mbp - artifacts: - paths: - - "*.mbp" - expire_in: never diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 50d0c15..445a5e7 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,7 +10,7 @@ default: - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY build frontend: - image: node:22-alpine + image: node:16-alpine stage: build frontend before_script: [] variables: @@ -58,8 +58,8 @@ manifest: script: - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 - - if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi - - if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi + - if [ $CI_COMMIT_BRANCH == "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi + - if [ $CI_COMMIT_BRANCH != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi - docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 @@ -70,6 +70,5 @@ build standalone amd64: script: - docker pull $CI_REGISTRY_IMAGE:standalone || true - docker build --pull --cache-from $CI_REGISTRY_IMAGE:standalone --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone . -f maubot/standalone/Dockerfile - - if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi - - if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi - - docker rmi $CI_REGISTRY_IMAGE:standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone || true + - if [ $CI_COMMIT_BRANCH == "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi + - if [ $CI_COMMIT_BRANCH != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 4a6328e..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,20 +0,0 @@ -repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 - hooks: - - id: trailing-whitespace - exclude_types: [markdown] - - id: end-of-file-fixer - - id: check-yaml - - id: check-added-large-files - - repo: https://github.com/psf/black - rev: 24.10.0 - hooks: - - id: black - language_version: python3 - files: ^maubot/.*\.pyi?$ - - repo: https://github.com/PyCQA/isort - rev: 5.13.2 - hooks: - - id: isort - files: ^maubot/.*\.pyi?$ diff --git a/CHANGELOG.md b/CHANGELOG.md deleted file mode 100644 index d9de2b7..0000000 --- a/CHANGELOG.md +++ /dev/null @@ -1,164 +0,0 @@ -# v0.5.2 (2025-05-05) - -* Improved tombstone handling to ensure that the tombstone sender has - permissions to invite users to the target room. -* Fixed autojoin and online flags not being applied if set during client - creation (thanks to [@bnsh] in [#258]). -* Fixed plugin web apps not being cleared properly when unloading plugins. - -[@bnsh]: https://github.com/bnsh -[#258]: https://github.com/maubot/maubot/pull/258 - -# v0.5.1 (2025-01-03) - -* Updated Docker image to Alpine 3.21. -* Updated media upload/download endpoints in management frontend - (thanks to [@domrim] in [#253]). -* Fixed plugin web app base path not including a trailing slash - (thanks to [@jkhsjdhjs] in [#240]). -* Changed markdown parsing to cut off plaintext body if necessary to allow - longer formatted messages. -* Updated dependencies to fix Python 3.13 compatibility. - -[@domrim]: https://github.com/domrim -[@jkhsjdhjs]: https://github.com/jkhsjdhjs -[#253]: https://github.com/maubot/maubot/pull/253 -[#240]: https://github.com/maubot/maubot/pull/240 - -# v0.5.0 (2024-08-24) - -* Dropped Python 3.9 support. -* Updated Docker image to Alpine 3.20. -* Updated mautrix-python to 0.20.6 to support authenticated media. -* Removed hard dependency on SQLAlchemy. -* Fixed `main_class` to default to being loaded from the last module instead of - the first if a module name is not explicitly specified. - * This was already the [documented behavior](https://docs.mau.fi/maubot/dev/reference/plugin-metadata.html), - and loading from the first module doesn't make sense due to import order. -* Added simple scheduler utility for running background tasks periodically or - after a certain delay. -* Added testing framework for plugins (thanks to [@abompard] in [#225]). -* Changed `mbc build` to ignore directories declared in `modules` that are - missing an `__init__.py` file. - * Importing the modules at runtime would fail and break the plugin. - To include non-code resources outside modules in the mbp archive, - use `extra_files` instead. - -[#225]: https://github.com/maubot/maubot/issues/225 -[@abompard]: https://github.com/abompard - -# v0.4.2 (2023-09-20) - -* Updated Pillow to 10.0.1. -* Updated Docker image to Alpine 3.18. -* Added logging for errors for /whoami errors when adding new bot accounts. -* Added support for using appservice tokens (including appservice encryption) - in standalone mode. - -# v0.4.1 (2023-03-15) - -* Added `in_thread` parameter to `evt.reply()` and `evt.respond()`. - * By default, responses will go to the thread if the command is in a thread. - * By setting the flag to `True` or `False`, the plugin can force the response - to either be or not be in a thread. -* Fixed static files like the frontend app manifest not being served correctly. -* Fixed `self.loader.meta` not being available to plugins in standalone mode. -* Updated to mautrix-python v0.19.6. - -# v0.4.0 (2023-01-29) - -* Dropped support for using a custom maubot API base path. - * The public URL can still have a path prefix, e.g. when using a reverse - proxy. Both the web interface and `mbc` CLI tool should work fine with - custom prefixes. -* Added `evt.redact()` as a shortcut for `self.client.redact(evt.room_id, evt.event_id)`. -* Fixed `mbc logs` command not working on Python 3.8+. -* Fixed saving plugin configs (broke in v0.3.0). -* Fixed SSO login using the wrong API path (probably broke in v0.3.0). -* Stopped using `cd` in the docker image's `mbc` wrapper to enable using - path-dependent commands like `mbc build` by mounting a directory. -* Updated Docker image to Alpine 3.17. - -# v0.3.1 (2022-03-29) - -* Added encryption dependencies to standalone dockerfile. -* Fixed running without encryption dependencies installed. -* Removed unnecessary imports that broke on SQLAlchemy 1.4+. -* Removed unused alembic dependency. - -# v0.3.0 (2022-03-28) - -* Dropped Python 3.7 support. -* Switched main maubot database to asyncpg/aiosqlite. - * Using the same SQLite database for crypto is now safe again. -* Added support for asyncpg/aiosqlite for plugin databases. - * There are some [basic docs](https://docs.mau.fi/maubot/dev/database/index.html) - and [a simple example](./examples/database) for the new system. - * The old SQLAlchemy system is now deprecated, but will be preserved for - backwards-compatibility until most plugins have updated. -* Started enforcing minimum maubot version in plugins. - * Trying to upload a plugin where the specified version is higher than the - running maubot version will fail. -* Fixed bug where uploading a plugin twice, deleting it and trying to upload - again would fail. -* Updated Docker image to Alpine 3.15. -* Formatted all code using [black](https://github.com/psf/black) - and [isort](https://github.com/PyCQA/isort). - -# v0.2.1 (2021-11-22) - -Docker-only release: added automatic moving of plugin databases from -`/data/plugins/*.db` to `/data/dbs` - -# v0.2.0 (2021-11-20) - -* Moved plugin databases from `/data/plugins` to `/data/dbs` in the docker image. - * v0.2.0 was missing the automatic migration of databases, it was added in v0.2.1. - * If you were using a custom path, you'll have to mount it at `/data/dbs` or - move the databases yourself. -* Removed support for pickle crypto store and added support for SQLite crypto store. - * **If you were previously using the dangerous pickle store for e2ee, you'll - have to re-login with the bots (which can now be done conveniently with - `mbc auth --update-client`).** -* Added SSO support to `mbc auth`. -* Added support for setting device ID for e2ee using the web interface. -* Added e2ee fingerprint field to the web interface. -* Added `--update-client` flag to store access token inside maubot instead of - returning it in `mbc auth`. - * This will also automatically store the device ID now. -* Updated standalone mode. - * Added e2ee and web server support. - * It's now officially supported and [somewhat documented](https://docs.mau.fi/maubot/usage/standalone.html). -* Replaced `_` with `-` when generating command name from function name. -* Replaced unmaintained PyInquirer dependency with questionary - (thanks to [@TinfoilSubmarine] in [#139]). -* Updated Docker image to Alpine 3.14. -* Fixed avatar URLs without the `mxc://` prefix appearing like they work in the - frontend, but not actually working when saved. - -[@TinfoilSubmarine]: https://github.com/TinfoilSubmarine -[#139]: https://github.com/maubot/maubot/pull/139 - -# v0.1.2 (2021-06-12) - -* Added `loader` instance property for plugins to allow reading files within - the plugin archive. -* Added support for reloading `webapp` and `database` meta flags in plugins. - Previously you had to restart maubot instead of just reloading the plugin - when enabling the webapp or database for the first time. -* Added warning log if a plugin uses `@web` decorators without enabling the - `webapp` meta flag. -* Updated frontend to latest React and dependency versions. -* Updated Docker image to Alpine 3.13. -* Fixed registering accounts with Synapse shared secret registration. -* Fixed plugins using `get_event` in encrypted rooms. -* Fixed using the `@command.new` decorator without specifying a name - (i.e. falling back to the function name). - -# v0.1.1 (2021-05-02) - -No changelog. - -# v0.1.0 (2020-10-04) - -Initial tagged release. diff --git a/Dockerfile b/Dockerfile index 2c6bad4..1b8eb4d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,34 +1,36 @@ -FROM node:22 AS frontend-builder +FROM node:16 AS frontend-builder COPY ./maubot/management/frontend /frontend RUN cd /frontend && yarn --prod && yarn build -FROM alpine:3.21 +FROM alpine:3.13 RUN apk add --no-cache \ python3 py3-pip py3-setuptools py3-wheel \ ca-certificates \ su-exec \ - yq \ py3-aiohttp \ + py3-sqlalchemy \ py3-attrs \ py3-bcrypt \ py3-cffi \ + py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ py3-packaging \ py3-markdown \ py3-alembic \ - py3-cssselect \ +# py3-cssselect \ py3-commonmark \ py3-pygments \ py3-tz \ +# py3-tzlocal \ py3-regex \ py3-wcwidth \ # encryption py3-cffi \ - py3-olm \ + olm-dev \ py3-pycryptodome \ py3-unpaddedbase64 \ py3-future \ @@ -38,20 +40,21 @@ RUN apk add --no-cache \ py3-feedparser \ py3-dateutil \ py3-lxml \ - py3-semver + py3-gitlab +# py3-semver@edge # TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \ - dateparser langdetect python-gitlab pyquery tzlocal \ + && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ + && pip3 install -r requirements.txt -r optional-requirements.txt \ + dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \ && apk del .build-deps # TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies COPY . /opt/maubot -RUN cp maubot/example-config.yaml . COPY ./docker/mbc.sh /usr/local/bin/mbc COPY --from=frontend-builder /frontend/build /opt/maubot/frontend ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data diff --git a/Dockerfile.ci b/Dockerfile.ci index 9712a16..950fac1 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1,14 +1,15 @@ -FROM alpine:3.21 +FROM alpine:3.13 RUN apk add --no-cache \ python3 py3-pip py3-setuptools py3-wheel \ ca-certificates \ su-exec \ - yq \ py3-aiohttp \ + py3-sqlalchemy \ py3-attrs \ py3-bcrypt \ py3-cffi \ + py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ @@ -24,7 +25,7 @@ RUN apk add --no-cache \ py3-wcwidth \ # encryption py3-cffi \ - py3-olm \ + olm-dev \ py3-pycryptodome \ py3-unpaddedbase64 \ py3-future \ @@ -32,8 +33,8 @@ RUN apk add --no-cache \ py3-pillow \ py3-magic \ py3-feedparser \ - py3-lxml -# py3-gitlab + py3-lxml \ + py3-gitlab # py3-semver # TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies @@ -41,13 +42,13 @@ COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \ - dateparser langdetect python-gitlab pyquery semver tzlocal cssselect \ + && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ + && pip3 install -r requirements.txt -r optional-requirements.txt \ + dateparser langdetect pyquery cchardet semver tzlocal cssselect \ && apk del .build-deps # TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies COPY . /opt/maubot -RUN cp /opt/maubot/maubot/example-config.yaml /opt/maubot COPY ./docker/mbc.sh /usr/local/bin/mbc ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data VOLUME /data diff --git a/Dockerfile.local b/Dockerfile.local deleted file mode 100644 index d37e220..0000000 --- a/Dockerfile.local +++ /dev/null @@ -1,29 +0,0 @@ -FROM r.batts.cloud/nodejs:22 AS frontend-builder - -COPY ./maubot/management/frontend /frontend -RUN cd /frontend && yarn --prod && yarn build - -FROM r.batts.cloud/debian:bookworm - -RUN apt update && \ - apt install -y --no-install-recommends python3 python3-dev python3-venv python3-semver git gosu yq brotli && \ - apt clean -y && \ - rm -rf /var/lib/apt/lists/* - -COPY requirements.txt /opt/maubot/requirements.txt -COPY optional-requirements.txt /opt/maubot/optional-requirements.txt -WORKDIR /opt/maubot -RUN python3 -m venv /venv \ - && bash -c 'source /venv/bin/activate \ - && pip3 install -r requirements.txt -r optional-requirements.txt \ - dateparser langdetect python-gitlab pyquery tzlocal pyfiglet emoji feedparser brotli' -# TODO also remove pyfiglet, emoji, dateparser, langdetect and pyquery when maubot supports installing dependencies - -COPY . /opt/maubot -RUN cp maubot/example-config.yaml . -COPY ./docker/mbc.sh /usr/local/bin/mbc -COPY --from=frontend-builder /frontend/build /opt/maubot/frontend -ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data -VOLUME /data - -CMD ["/opt/maubot/docker/run.sh"] diff --git a/MANIFEST.in b/MANIFEST.in index d8889bc..daa36da 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include README.md -include CHANGELOG.md include LICENSE include requirements.txt include optional-requirements.txt diff --git a/README.md b/README.md index 02a4b6f..75be206 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,4 @@ # maubot -![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg) -[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE) -[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases) -[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry) -[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) -[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) - A plugin-based [Matrix](https://matrix.org) bot system written in Python. ## Documentation @@ -22,8 +15,41 @@ All setup and usage instructions are located on Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net) ## Plugins -A list of plugins can be found at [plugins.mau.bot](https://plugins.mau.bot/). +* [jesaribot](https://github.com/maubot/jesaribot) - A simple bot that replies with an image when you say "jesari". +* [sed](https://github.com/maubot/sed) - A bot to do sed-like replacements. +* [factorial](https://github.com/maubot/factorial) - A bot to calculate unexpected factorials. +* [media](https://github.com/maubot/media) - A bot that replies with the MXC URI of images you send it. +* [dice](https://github.com/maubot/dice) - A combined dice rolling and calculator bot. +* [karma](https://github.com/maubot/karma) - A user karma tracker bot. +* [xkcd](https://github.com/maubot/xkcd) - A bot to view xkcd comics. +* [echo](https://github.com/maubot/echo) - A bot that echoes pings and other stuff. +* [rss](https://github.com/maubot/rss) - A bot that posts RSS feed updates to Matrix. +* [reddit](https://github.com/TomCasavant/RedditMaubot) - A bot that condescendingly corrects a user when they enter an r/subreddit without providing a link to that subreddit +* [giphy](https://github.com/TomCasavant/GiphyMaubot) - A bot that generates a gif (from giphy) given search terms +* [trump](https://github.com/jeffcasavant/MaubotTrumpTweet) - A bot that generates a Trump tweet with the given content +* [poll](https://github.com/TomCasavant/PollMaubot) - A bot that will create a simple poll for users in a room +* [urban](https://github.com/dvdgsng/UrbanMaubot) - A bot that fetches definitions from [Urban Dictionary](https://www.urbandictionary.com/). +* [reminder](https://github.com/maubot/reminder) - A bot to remind you about things. +* [translate](https://github.com/maubot/translate) - A bot to translate words. +* [reactbot](https://github.com/maubot/reactbot) - A bot that responds to messages that match predefined rules. +* [exec](https://github.com/maubot/exec) - A bot that executes code. +* [commitstrip](https://github.com/maubot/commitstrip) - A bot to view CommitStrips. +* [supportportal](https://github.com/maubot/supportportal) - A bot to manage customer support on Matrix. +* [gitlab](https://github.com/maubot/gitlab) - A GitLab client and webhook receiver. +* [github](https://github.com/maubot/github) - A GitHub client and webhook receiver. +* [gitea](https://github.com/saces/maugitea) - A Gitea client and webhook receiver. +* [twilio](https://github.com/jeffcasavant/MaubotTwilio) - Maubot-based SMS bridge +* [tmdb](https://codeberg.org/lomion/tmdb-bot) - A bot that posts information about movies fetched from TheMovieDB.org. +* [tex](https://github.com/maubot/tex) - A bot that renders LaTeX. +* [altalias](https://github.com/maubot/altalias) - A bot that lets users publish alternate aliases in rooms. +* [satwcomic](https://github.com/maubot/satwcomic) - A bot to view SatWComics. +* [songwhip](https://github.com/maubot/songwhip) - A bot to post Songwhip links. +* [invite](https://github.com/williamkray/maubot-invite) - A bot to generate invitation tokens from [matrix-registration](https://github.com/ZerataX/matrix-registration) +* [wolframalpha](https://github.com/ggogel/WolframAlphaMaubot) - A bot that allows requesting information from [WolframAlpha](https://www.wolframalpha.com/). +* [pingcheck](https://edugit.org/nik/maubot-pingcheck) - A bot to ping the echo bot and send rtt to Icinga passive check +* [ticker](https://github.com/williamkray/maubot-ticker) - A bot to return financial data about a stock or cryptocurrency. +* [weather](https://github.com/kellya/maubot-weather) - A bot to get the weather from wttr.in and return a single line of text for the location specified -To add your plugin to the list, send a pull request to . +Open a pull request or join the Matrix room linked above to get your plugin listed here -The plugin wishlist lives at . +The plugin wishlist lives at https://github.com/maubot/plugin-wishlist/issues diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..0d78e89 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,83 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration files +# file_template = %%(rev)s_%%(slug)s + +# timezone to use when rendering the date +# within the migration file as well as the filename. +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; this defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path +# version_locations = %(here)s/bar %(here)s/bat alembic/versions + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks=black +# black.type=console_scripts +# black.entrypoint=black +# black.options=-l 79 + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 0000000..98e4f9c --- /dev/null +++ b/alembic/README @@ -0,0 +1 @@ +Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..9946810 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,92 @@ +from logging.config import fileConfig + +from sqlalchemy import engine_from_config, pool + +from alembic import context + +import sys +from os.path import abspath, dirname + +sys.path.insert(0, dirname(dirname(abspath(__file__)))) + +from mautrix.util.db import Base +from maubot.config import Config +from maubot import db + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml") +maubot_config = Config(maubot_config_path, None) +maubot_config.load() +config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%")) + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline(): + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + render_as_batch=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online(): + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = engine_from_config( + config.get_section(config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata, + render_as_batch=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..2c01563 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/4b93300852aa_add_device_id_to_clients.py b/alembic/versions/4b93300852aa_add_device_id_to_clients.py new file mode 100644 index 0000000..efc71cd --- /dev/null +++ b/alembic/versions/4b93300852aa_add_device_id_to_clients.py @@ -0,0 +1,32 @@ +"""Add device_id to clients + +Revision ID: 4b93300852aa +Revises: fccd1f95544d +Create Date: 2020-07-11 15:49:38.831459 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '4b93300852aa' +down_revision = 'fccd1f95544d' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('client', schema=None) as batch_op: + batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('client', schema=None) as batch_op: + batch_op.drop_column('device_id') + + # ### end Alembic commands ### diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py new file mode 100644 index 0000000..37a68eb --- /dev/null +++ b/alembic/versions/90aa88820eab_add_matrix_state_store.py @@ -0,0 +1,47 @@ +"""Add Matrix state store + +Revision ID: 90aa88820eab +Revises: 4b93300852aa +Create Date: 2020-07-12 01:50:06.215623 + +""" +from alembic import op +import sqlalchemy as sa + +from mautrix.client.state_store.sqlalchemy import SerializableType +from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent + + +# revision identifiers, used by Alembic. +revision = '90aa88820eab' +down_revision = '4b93300852aa' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('mx_room_state', + sa.Column('room_id', sa.String(length=255), nullable=False), + sa.Column('is_encrypted', sa.Boolean(), nullable=True), + sa.Column('has_full_member_list', sa.Boolean(), nullable=True), + sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True), + sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True), + sa.PrimaryKeyConstraint('room_id') + ) + op.create_table('mx_user_profile', + sa.Column('room_id', sa.String(length=255), nullable=False), + sa.Column('user_id', sa.String(length=255), nullable=False), + sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False), + sa.Column('displayname', sa.String(), nullable=True), + sa.Column('avatar_url', sa.String(length=255), nullable=True), + sa.PrimaryKeyConstraint('room_id', 'user_id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('mx_user_profile') + op.drop_table('mx_room_state') + # ### end Alembic commands ### diff --git a/alembic/versions/d295f8dcfa64_initial_revision.py b/alembic/versions/d295f8dcfa64_initial_revision.py new file mode 100644 index 0000000..ffa502f --- /dev/null +++ b/alembic/versions/d295f8dcfa64_initial_revision.py @@ -0,0 +1,50 @@ +"""Initial revision + +Revision ID: d295f8dcfa64 +Revises: +Create Date: 2019-09-27 00:21:02.527915 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd295f8dcfa64' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('client', + sa.Column('id', sa.String(length=255), nullable=False), + sa.Column('homeserver', sa.String(length=255), nullable=False), + sa.Column('access_token', sa.Text(), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.Column('next_batch', sa.String(length=255), nullable=False), + sa.Column('filter_id', sa.String(length=255), nullable=False), + sa.Column('sync', sa.Boolean(), nullable=False), + sa.Column('autojoin', sa.Boolean(), nullable=False), + sa.Column('displayname', sa.String(length=255), nullable=False), + sa.Column('avatar_url', sa.String(length=255), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('plugin', + sa.Column('id', sa.String(length=255), nullable=False), + sa.Column('type', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), nullable=False), + sa.Column('primary_user', sa.String(length=255), nullable=False), + sa.Column('config', sa.Text(), nullable=False), + sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('plugin') + op.drop_table('client') + # ### end Alembic commands ### diff --git a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py b/alembic/versions/fccd1f95544d_add_online_field_to_clients.py new file mode 100644 index 0000000..1f7eabe --- /dev/null +++ b/alembic/versions/fccd1f95544d_add_online_field_to_clients.py @@ -0,0 +1,30 @@ +"""Add online field to clients + +Revision ID: fccd1f95544d +Revises: d295f8dcfa64 +Create Date: 2020-03-06 15:07:50.136644 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'fccd1f95544d' +down_revision = 'd295f8dcfa64' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("client") as batch_op: + batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true())) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("client") as batch_op: + batch_op.drop_column('online') + # ### end Alembic commands ### diff --git a/dev-requirements.txt b/dev-requirements.txt deleted file mode 100644 index bb8c2a0..0000000 --- a/dev-requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -pre-commit>=2.10.1,<3 -isort>=5.10.1,<6 -black>=24,<25 diff --git a/docker/example-config.yaml b/docker/example-config.yaml new file mode 100644 index 0000000..192a420 --- /dev/null +++ b/docker/example-config.yaml @@ -0,0 +1,108 @@ +# The full URI to the database. SQLite and Postgres are fully supported. +# Other DBMSes supported by SQLAlchemy may or may not work. +# Format examples: +# SQLite: sqlite:///filename.db +# Postgres: postgres://username:password@hostname/dbname +database: sqlite:////data/maubot.db + +# Database for encryption data. +crypto_database: + # Type of database. Either "default", "pickle" or "postgres". + # When set to default, using SQLite as the main database will use pickle as the crypto database + # and using Postgres as the main database will use the same one as the crypto database. + # + # When using pickle, individual crypto databases are stored in the pickle_dir directory. + # When using non-default postgres, postgres_uri is used to connect to postgres. + type: default + postgres_uri: postgres://username:password@hostname/dbname + pickle_dir: /data/crypto + +plugin_directories: + # The directory where uploaded new plugins should be stored. + upload: /data/plugins + # The directories from which plugins should be loaded. + # Duplicate plugin IDs will be moved to the trash. + load: + - /data/plugins + # The directory where old plugin versions and conflicting plugins should be moved. + # Set to "delete" to delete files immediately. + trash: /data/trash + # The directory where plugin databases should be stored. + db: /data/plugins + +server: + # The IP and port to listen to. + hostname: 0.0.0.0 + port: 29316 + # Public base URL where the server is visible. + public_url: https://example.com + # The base management API path. + base_path: /_matrix/maubot/v1 + # The base path for the UI. + ui_base_path: /_matrix/maubot + # The base path for plugin endpoints. The instance ID will be appended directly. + plugin_base_path: /_matrix/maubot/plugin/ + # Override path from where to load UI resources. + # Set to false to using pkg_resources to find the path. + override_resource_path: /opt/maubot/frontend + # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. + appservice_base_path: /_matrix/app/v1 + # The shared secret to sign API access tokens. + # Set to "generate" to generate and save a new token at startup. + unshared_secret: generate + +# Shared registration secrets to allow registering new users from the management UI +registration_secrets: + example.com: + # Client-server API URL + url: https://example.com + # registration_shared_secret from synapse config + secret: synapse_shared_registration_secret + +# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password +# to prevent normal login. Root is a special user that can't have a password and will always exist. +admins: + root: "" + +# API feature switches. +api_features: + login: true + plugin: true + plugin_upload: true + instance: true + instance_database: true + client: true + client_proxy: true + client_auth: true + dev_open: true + log: true + +# Python logging configuration. +# +# See section 16.7.2 of the Python documentation for more info: +# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema +logging: + version: 1 + formatters: + precise: + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + handlers: + file: + class: logging.handlers.RotatingFileHandler + formatter: precise + filename: /var/log/maubot.log + maxBytes: 10485760 + backupCount: 10 + console: + class: logging.StreamHandler + formatter: precise + loggers: + maubot: + level: DEBUG + mautrix: + level: DEBUG + aiohttp: + level: INFO + root: + level: DEBUG + handlers: [file, console] diff --git a/docker/mbc.sh b/docker/mbc.sh index 5bde65a..bffbd5e 100755 --- a/docker/mbc.sh +++ b/docker/mbc.sh @@ -1,3 +1,3 @@ #!/bin/sh -export PYTHONPATH=/opt/maubot +cd /opt/maubot python3 -m maubot.cli "$@" diff --git a/docker/run.sh b/docker/run.sh index 1ec95a2..96a60b9 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -1,50 +1,21 @@ -#!/bin/bash +#!/bin/sh function fixperms { - chown -R $UID:$GID /var/log /data -} - -function fixdefault { - _value=$(yq e "$1" /data/config.yaml) - if [[ "$_value" == "$2" ]]; then - yq e -i "$1 = "'"'"$3"'"' /data/config.yaml - fi -} - -function fixconfig { - # Change relative default paths to absolute paths in /data - fixdefault '.database' 'sqlite:maubot.db' 'sqlite:/data/maubot.db' - fixdefault '.plugin_directories.upload' './plugins' '/data/plugins' - fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins' - fixdefault '.plugin_directories.trash' './trash' '/data/trash' - fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs' - fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs' - fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log' - # This doesn't need to be configurable - yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml + chown -R $UID:$GID /var/log /data /opt/maubot } cd /opt/maubot -mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs +mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs /data/crypto if [ ! -f /data/config.yaml ]; then - cp example-config.yaml /data/config.yaml + cp docker/example-config.yaml /data/config.yaml echo "Config file not found. Example config copied to /data/config.yaml" echo "Please modify the config file to your liking and restart the container." fixperms - fixconfig exit fi +alembic -x config=/data/config.yaml upgrade head fixperms -fixconfig -if ls /data/plugins/*.db > /dev/null 2>&1; then - mv -n /data/plugins/*.db /data/dbs/ -fi - -if [ -f "/venv/bin/activate" ] ; then - exec gosu $UID:$GID bash -c 'source /venv/bin/activate && python3 -m maubot -c /data/config.yaml' -fi -exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml - +exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml -b docker/example-config.yaml diff --git a/example-config.yaml b/example-config.yaml new file mode 100644 index 0000000..89d3965 --- /dev/null +++ b/example-config.yaml @@ -0,0 +1,113 @@ +# The full URI to the database. SQLite and Postgres are fully supported. +# Other DBMSes supported by SQLAlchemy may or may not work. +# Format examples: +# SQLite: sqlite:///filename.db +# Postgres: postgres://username:password@hostname/dbname +database: sqlite:///maubot.db + +# Database for encryption data. +crypto_database: + # Type of database. Either "default", "pickle" or "postgres". + # When set to default, using SQLite as the main database will use pickle as the crypto database + # and using Postgres as the main database will use the same one as the crypto database. + # + # When using pickle, individual crypto databases are stored in the pickle_dir directory. + # When using non-default postgres, postgres_uri is used to connect to postgres. + # + # WARNING: The pickle database is dangerous and should not be used in production. + type: default + postgres_uri: postgres://username:password@hostname/dbname + pickle_dir: ./crypto + +plugin_directories: + # The directory where uploaded new plugins should be stored. + upload: ./plugins + # The directories from which plugins should be loaded. + # Duplicate plugin IDs will be moved to the trash. + load: + - ./plugins + # The directory where old plugin versions and conflicting plugins should be moved. + # Set to "delete" to delete files immediately. + trash: ./trash + # The directory where plugin databases should be stored. + db: ./plugins + +server: + # The IP and port to listen to. + hostname: 0.0.0.0 + port: 29316 + # Public base URL where the server is visible. + public_url: https://example.com + # The base management API path. + base_path: /_matrix/maubot/v1 + # The base path for the UI. + ui_base_path: /_matrix/maubot + # The base path for plugin endpoints. The instance ID will be appended directly. + plugin_base_path: /_matrix/maubot/plugin/ + # Override path from where to load UI resources. + # Set to false to using pkg_resources to find the path. + override_resource_path: false + # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. + appservice_base_path: /_matrix/app/v1 + # The shared secret to sign API access tokens. + # Set to "generate" to generate and save a new token at startup. + unshared_secret: generate + +# Shared registration secrets to allow registering new users from the management UI +registration_secrets: + example.com: + # Client-server API URL + url: https://example.com + # registration_shared_secret from synapse config + secret: synapse_shared_registration_secret + +# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password +# to prevent normal login. Root is a special user that can't have a password and will always exist. +admins: + root: "" + +# API feature switches. +api_features: + login: true + plugin: true + plugin_upload: true + instance: true + instance_database: true + client: true + client_proxy: true + client_auth: true + dev_open: true + log: true + +# Python logging configuration. +# +# See section 16.7.2 of the Python documentation for more info: +# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema +logging: + version: 1 + formatters: + colored: + (): maubot.lib.color_log.ColorFormatter + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + normal: + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + handlers: + file: + class: logging.handlers.RotatingFileHandler + formatter: normal + filename: ./maubot.log + maxBytes: 10485760 + backupCount: 10 + console: + class: logging.StreamHandler + formatter: colored + loggers: + maubot: + level: DEBUG + mautrix: + level: DEBUG + aiohttp: + level: INFO + root: + level: DEBUG + handlers: [file, console] diff --git a/examples/LICENSE b/examples/LICENSE index a4b60f3..bfdfe68 100644 --- a/examples/LICENSE +++ b/examples/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2022 Tulir Asokan +Copyright (c) 2018 Tulir Asokan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/examples/README.md b/examples/README.md index 2efabca..1837fec 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,4 +4,3 @@ All examples are published under the [MIT license](LICENSE). * [Hello World](helloworld/) - Very basic event handling bot that responds "Hello, World!" to all messages. * [Echo bot](https://github.com/maubot/echo) - Basic command handling bot with !echo and !ping commands * [Config example](config/) - Simple example of using a config file -* [Database example](database/) - Simple example of using a database diff --git a/examples/config/base-config.yaml b/examples/config/base-config.yaml index 0f7f8a3..c621847 100644 --- a/examples/config/base-config.yaml +++ b/examples/config/base-config.yaml @@ -1,5 +1,2 @@ -# Who is allowed to use the bot? -whitelist: - - "@user:example.com" -# The prefix for the main command without the ! -command_prefix: hello-world +# Message to send when user sends !getmessage +message: Default configuration active diff --git a/examples/config/configurablebot.py b/examples/config/configurablebot.py index 54b47b6..13624be 100644 --- a/examples/config/configurablebot.py +++ b/examples/config/configurablebot.py @@ -1,4 +1,5 @@ from typing import Type + from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from maubot import Plugin, MessageEvent from maubot.handlers import command @@ -6,22 +7,19 @@ from maubot.handlers import command class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: - helper.copy("whitelist") - helper.copy("command_prefix") + helper.copy("message") -class ConfigurableBot(Plugin): +class DatabaseBot(Plugin): async def start(self) -> None: + await super().start() self.config.load_and_update() - def get_command_name(self) -> str: - return self.config["command_prefix"] - - @command.new(name=get_command_name) - async def hmm(self, evt: MessageEvent) -> None: - if evt.sender in self.config["whitelist"]: - await evt.reply("You're whitelisted 🎉") - @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: return Config + + @command.new("getmessage") + async def handler(self, event: MessageEvent) -> None: + if event.sender != self.client.mxid: + await event.reply(self.config["message"]) diff --git a/examples/config/maubot.yaml b/examples/config/maubot.yaml index 8ab36a9..b049dba 100644 --- a/examples/config/maubot.yaml +++ b/examples/config/maubot.yaml @@ -1,12 +1,11 @@ maubot: 0.1.0 -id: xyz.maubot.configurablebot -version: 2.0.0 +id: xyz.maubot.databasebot +version: 1.0.0 license: MIT modules: - configurablebot main_class: ConfigurableBot database: false -config: true # Instruct the build tool to include the base config. extra_files: diff --git a/examples/database/maubot.yaml b/examples/database/maubot.yaml deleted file mode 100644 index 84f4f69..0000000 --- a/examples/database/maubot.yaml +++ /dev/null @@ -1,10 +0,0 @@ -maubot: 0.1.0 -id: xyz.maubot.storagebot -version: 2.0.0 -license: MIT -modules: -- storagebot -main_class: StorageBot -database: true -database_type: asyncpg -config: false diff --git a/examples/database/storagebot.py b/examples/database/storagebot.py deleted file mode 100644 index 786bba5..0000000 --- a/examples/database/storagebot.py +++ /dev/null @@ -1,72 +0,0 @@ -from __future__ import annotations - -from mautrix.util.async_db import UpgradeTable, Connection -from maubot import Plugin, MessageEvent -from maubot.handlers import command - -upgrade_table = UpgradeTable() - - -@upgrade_table.register(description="Initial revision") -async def upgrade_v1(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE stored_data ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )""" - ) - - -@upgrade_table.register(description="Remember user who added value") -async def upgrade_v2(conn: Connection) -> None: - await conn.execute("ALTER TABLE stored_data ADD COLUMN creator TEXT") - - -class StorageBot(Plugin): - @command.new() - async def storage(self, evt: MessageEvent) -> None: - pass - - @storage.subcommand(help="Store a value") - @command.argument("key") - @command.argument("value", pass_raw=True) - async def put(self, evt: MessageEvent, key: str, value: str) -> None: - q = """ - INSERT INTO stored_data (key, value, creator) VALUES ($1, $2, $3) - ON CONFLICT (key) DO UPDATE SET value=excluded.value, creator=excluded.creator - """ - await self.database.execute(q, key, value, evt.sender) - await evt.reply(f"Inserted {key} into the database") - - @storage.subcommand(help="Get a value from the storage") - @command.argument("key") - async def get(self, evt: MessageEvent, key: str) -> None: - q = "SELECT key, value, creator FROM stored_data WHERE LOWER(key)=LOWER($1)" - row = await self.database.fetchrow(q, key) - if row: - key = row["key"] - value = row["value"] - creator = row["creator"] - await evt.reply(f"`{key}` stored by {creator}:\n\n```\n{value}\n```") - else: - await evt.reply(f"No data stored under `{key}` :(") - - @storage.subcommand(help="List keys in the storage") - @command.argument("prefix", required=False) - async def list(self, evt: MessageEvent, prefix: str | None) -> None: - q = "SELECT key, creator FROM stored_data WHERE key LIKE $1" - rows = await self.database.fetch(q, prefix + "%") - prefix_reply = f" starting with `{prefix}`" if prefix else "" - if len(rows) == 0: - await evt.reply(f"Nothing{prefix_reply} stored in database :(") - else: - formatted_data = "\n".join( - f"* `{row['key']}` stored by {row['creator']}" for row in rows - ) - await evt.reply( - f"Found {len(rows)} keys{prefix_reply} in database:\n\n{formatted_data}" - ) - - @classmethod - def get_db_upgrade_table(cls) -> UpgradeTable | None: - return upgrade_table diff --git a/maubot/__init__.py b/maubot/__init__.py index 5106b46..9ca0322 100644 --- a/maubot/__init__.py +++ b/maubot/__init__.py @@ -1,4 +1,3 @@ -from .__meta__ import __version__ -from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent from .plugin_base import Plugin from .plugin_server import PluginWebApp +from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent diff --git a/maubot/__main__.py b/maubot/__main__.py index c4cba44..2ef73f9 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,171 +13,84 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - +import logging.config +import argparse import asyncio +import signal +import copy import sys -from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme -from mautrix.util.program import Program - -from .__meta__ import __version__ -from .client import Client from .config import Config -from .db import init as init_db, upgrade_table -from .instance import PluginInstance -from .lib.future_awaitable import FutureAwaitable -from .lib.state_store import PgStateStore -from .loader.zip import init as init_zip_loader -from .management.api import init as init_mgmt_api +from .db import init as init_db from .server import MaubotServer +from .client import Client, init as init_client_class +from .loader.zip import init as init_zip_loader +from .instance import init as init_plugin_instance_class +from .management.api import init as init_mgmt_api +from .__meta__ import __version__ + +parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", + prog="python -m maubot") +parser.add_argument("-c", "--config", type=str, default="config.yaml", + metavar="", help="the path to your config file") +parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml", + metavar="", help="the path to the example config " + "(for automatic config updates)") +args = parser.parse_args() + +config = Config(args.config, args.base_config) +config.load() +config.update() + +logging.config.dictConfig(copy.deepcopy(config["logging"])) + +loop = asyncio.get_event_loop() + +stop_log_listener = None +if config["api_features.log"]: + from .management.api.log import init as init_log_listener, stop_all as stop_log_listener + + init_log_listener(loop) + +log = logging.getLogger("maubot.init") +log.info(f"Initializing maubot {__version__}") + +init_zip_loader(config) +db_engine = init_db(config) +clients = init_client_class(config, loop) +management_api = init_mgmt_api(config, loop) +server = MaubotServer(management_api, config, loop) +plugins = init_plugin_instance_class(config, server, loop) + +for plugin in plugins: + plugin.load() + +signal.signal(signal.SIGINT, signal.default_int_handler) +signal.signal(signal.SIGTERM, signal.default_int_handler) + try: - from mautrix.crypto.store import PgCryptoStore -except ImportError: - PgCryptoStore = None - - -class Maubot(Program): - config: Config - server: MaubotServer - db: Database - crypto_db: Database | None - plugin_postgres_db: PostgresDatabase | None - state_store: PgStateStore - - config_class = Config - module = "maubot" - name = "maubot" - version = __version__ - command = "python -m maubot" - description = "A plugin-based Matrix bot system." - - def prepare_log_websocket(self) -> None: - from .management.api.log import init, stop_all - - init(self.loop) - self.add_shutdown_actions(FutureAwaitable(stop_all)) - - def prepare_arg_parser(self) -> None: - super().prepare_arg_parser() - self.parser.add_argument( - "--ignore-unsupported-database", - action="store_true", - help="Run even if the database schema is too new", - ) - self.parser.add_argument( - "--ignore-foreign-tables", - action="store_true", - help="Run even if the database contains tables from other programs (like Synapse)", - ) - - def prepare_db(self) -> None: - self.db = Database.create( - self.config["database"], - upgrade_table=upgrade_table, - db_args=self.config["database_opts"], - owner_name=self.name, - ignore_foreign_tables=self.args.ignore_foreign_tables, - ) - init_db(self.db) - - if PgCryptoStore: - if self.config["crypto_database"] == "default": - self.crypto_db = self.db - else: - self.crypto_db = Database.create( - self.config["crypto_database"], - upgrade_table=PgCryptoStore.upgrade_table, - ignore_foreign_tables=self.args.ignore_foreign_tables, - ) - else: - self.crypto_db = None - - if self.config["plugin_databases.postgres"] == "default": - if self.db.scheme != Scheme.POSTGRES: - self.log.critical( - 'Using "default" as the postgres plugin database URL is only allowed if ' - "the default database is postgres." - ) - sys.exit(24) - assert isinstance(self.db, PostgresDatabase) - self.plugin_postgres_db = self.db - elif self.config["plugin_databases.postgres"]: - plugin_db = Database.create( - self.config["plugin_databases.postgres"], - db_args={ - **self.config["database_opts"], - **self.config["plugin_databases.postgres_opts"], - }, - ) - if plugin_db.scheme != Scheme.POSTGRES: - self.log.critical("The plugin postgres database URL must be a postgres database") - sys.exit(24) - assert isinstance(plugin_db, PostgresDatabase) - self.plugin_postgres_db = plugin_db - else: - self.plugin_postgres_db = None - - def prepare(self) -> None: - super().prepare() - - if self.config["api_features.log"]: - self.prepare_log_websocket() - - init_zip_loader(self.config) - self.prepare_db() - Client.init_cls(self) - PluginInstance.init_cls(self) - management_api = init_mgmt_api(self.config, self.loop) - self.server = MaubotServer(management_api, self.config, self.loop) - self.state_store = PgStateStore(self.db) - - async def start_db(self) -> None: - self.log.debug("Starting database...") - ignore_unsupported = self.args.ignore_unsupported_database - self.db.upgrade_table.allow_unsupported = ignore_unsupported - self.state_store.upgrade_table.allow_unsupported = ignore_unsupported - try: - await self.db.start() - await self.state_store.upgrade_table.upgrade(self.db) - if self.plugin_postgres_db and self.plugin_postgres_db is not self.db: - await self.plugin_postgres_db.start() - if self.crypto_db: - PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported - if self.crypto_db is not self.db: - await self.crypto_db.start() - else: - await PgCryptoStore.upgrade_table.upgrade(self.db) - except DatabaseException as e: - self.log.critical("Failed to initialize database", exc_info=e) - if e.explanation: - self.log.info(e.explanation) - sys.exit(25) - - async def system_exit(self) -> None: - if hasattr(self, "db"): - self.log.trace("Stopping database due to SystemExit") - await self.db.stop() - - async def start(self) -> None: - await self.start_db() - await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()]) - await asyncio.gather(*[client.start() async for client in Client.all()]) - await super().start() - async for plugin in PluginInstance.all(): - await plugin.load() - await self.server.start() - - async def stop(self) -> None: - self.add_shutdown_actions(*(client.stop() for client in Client.cache.values())) - await super().stop() - self.log.debug("Stopping server") - try: - await asyncio.wait_for(self.server.stop(), 5) - except asyncio.TimeoutError: - self.log.warning("Stopping server timed out") - await self.db.stop() - - -Maubot().run() + log.info("Starting server") + loop.run_until_complete(server.start()) + if Client.crypto_db: + log.debug("Starting client crypto database") + loop.run_until_complete(Client.crypto_db.start()) + log.info("Starting clients and plugins") + loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) + log.info("Startup actions complete, running forever") + loop.run_forever() +except KeyboardInterrupt: + log.info("Interrupt received, stopping clients") + loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()])) + if stop_log_listener is not None: + log.debug("Closing websockets") + loop.run_until_complete(stop_log_listener()) + log.debug("Stopping server") + try: + loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop)) + except asyncio.TimeoutError: + log.warning("Stopping server timed out") + log.debug("Closing event loop") + loop.close() + log.debug("Everything stopped, shutting down") + sys.exit(0) diff --git a/maubot/__meta__.py b/maubot/__meta__.py index 7225152..b3f4756 100644 --- a/maubot/__meta__.py +++ b/maubot/__meta__.py @@ -1 +1 @@ -__version__ = "0.5.2" +__version__ = "0.1.2" diff --git a/maubot/cli/__main__.py b/maubot/cli/__main__.py index 3bdbe0e..1ffd665 100644 --- a/maubot/cli/__main__.py +++ b/maubot/cli/__main__.py @@ -1,3 +1,2 @@ from . import app - app() diff --git a/maubot/cli/base.py b/maubot/cli/base.py index b35db53..1aaeec8 100644 --- a/maubot/cli/base.py +++ b/maubot/cli/base.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by diff --git a/maubot/cli/cliq/__init__.py b/maubot/cli/cliq/__init__.py index 10ede9f..cba14a4 100644 --- a/maubot/cli/cliq/__init__.py +++ b/maubot/cli/cliq/__init__.py @@ -1,2 +1,2 @@ from .cliq import command, option -from .validators import PathValidator, SPDXValidator, VersionValidator +from .validators import SPDXValidator, VersionValidator, PathValidator diff --git a/maubot/cli/cliq/cliq.py b/maubot/cli/cliq/cliq.py index 2883441..973587a 100644 --- a/maubot/cli/cliq/cliq.py +++ b/maubot/cli/cliq/cliq.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,55 +13,20 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import Any, Callable -import asyncio +from typing import Any, Callable, Union, Optional import functools -import inspect -import traceback -from colorama import Fore from prompt_toolkit.validation import Validator -from questionary import prompt -import aiohttp +from PyInquirer import prompt import click from ..base import app -from ..config import get_token -from .validators import ClickValidator, Required - - -def with_http(func): - @functools.wraps(func) - async def wrapper(*args, **kwargs): - async with aiohttp.ClientSession() as sess: - try: - return await func(*args, sess=sess, **kwargs) - except aiohttp.ClientError as e: - print(f"{Fore.RED}Connection error: {e}{Fore.RESET}") - - return wrapper - - -def with_authenticated_http(func): - @functools.wraps(func) - async def wrapper(*args, server: str, **kwargs): - server, token = get_token(server) - if not token: - return - async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess: - try: - return await func(*args, sess=sess, server=server, **kwargs) - except aiohttp.ClientError as e: - print(f"{Fore.RED}Connection error: {e}{Fore.RESET}") - - return wrapper +from .validators import Required, ClickValidator def command(help: str) -> Callable[[Callable], Callable]: def decorator(func) -> Callable: - questions = getattr(func, "__inquirer_questions__", {}).copy() + questions = func.__inquirer_questions__.copy() @functools.wraps(func) def wrapper(*args, **kwargs): @@ -74,11 +39,6 @@ def command(help: str) -> Callable[[Callable], Callable]: required_unless = questions[key].pop("required_unless") if isinstance(required_unless, str) and kwargs[required_unless]: questions.pop(key) - elif isinstance(required_unless, list): - for v in required_unless: - if kwargs[v]: - questions.pop(key) - break elif isinstance(required_unless, dict): for k, v in required_unless.items(): if kwargs.get(v, object()) == v: @@ -88,25 +48,18 @@ def command(help: str) -> Callable[[Callable], Callable]: pass question_list = list(questions.values()) question_list.reverse() - resp = prompt(question_list, kbi_msg="Aborted!") + resp = prompt(question_list, keyboard_interrupt_msg="Aborted!") if not resp and question_list: return kwargs = {**kwargs, **resp} - - try: - res = func(*args, **kwargs) - if inspect.isawaitable(res): - asyncio.run(res) - except Exception: - print(Fore.RED + "Fatal error running command" + Fore.RESET) - traceback.print_exc() + func(*args, **kwargs) return app.command(help=help)(wrapper) return decorator -def yesno(val: str) -> bool | None: +def yesno(val: str) -> Optional[bool]: if not val: return None elif isinstance(val, bool): @@ -120,25 +73,14 @@ def yesno(val: str) -> bool | None: yesno.__name__ = "yes/no" -def option( - short: str, - long: str, - message: str = None, - help: str = None, - click_type: str | Callable[[str], Any] = None, - inq_type: str = None, - validator: type[Validator] = None, - required: bool = False, - default: str | bool | None = None, - is_flag: bool = False, - prompt: bool = True, - required_unless: str | list | dict = None, -) -> Callable[[Callable], Callable]: +def option(short: str, long: str, message: str = None, help: str = None, + click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None, + validator: Validator = None, required: bool = False, default: str = None, + is_flag: bool = False, prompt: bool = True, required_unless: str = None + ) -> Callable[[Callable], Callable]: if not message: message = long[2].upper() + long[3:] - - if isinstance(validator, type) and issubclass(validator, ClickValidator): - click_type = validator.click_type + click_type = validator.click_type if isinstance(validator, ClickValidator) else click_type if is_flag: click_type = yesno @@ -149,9 +91,9 @@ def option( if not hasattr(func, "__inquirer_questions__"): func.__inquirer_questions__ = {} q = { - "type": ( - inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm") - ), + "type": (inq_type if isinstance(inq_type, str) + else ("input" if not is_flag + else "confirm")), "name": long[2:], "message": message, } @@ -160,9 +102,9 @@ def option( if default is not None: q["default"] = default if required or required_unless is not None: - q["validate"] = Required(validator) + q["validator"] = Required(validator) elif validator: - q["validate"] = validator + q["validator"] = validator func.__inquirer_questions__[long[2:]] = q return func diff --git a/maubot/cli/cliq/validators.py b/maubot/cli/cliq/validators.py index 46d3c92..9a57914 100644 --- a/maubot/cli/cliq/validators.py +++ b/maubot/cli/cliq/validators.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,9 +16,9 @@ from typing import Callable import os -from packaging.version import InvalidVersion, Version +from packaging.version import Version, InvalidVersion +from prompt_toolkit.validation import Validator, ValidationError from prompt_toolkit.document import Document -from prompt_toolkit.validation import ValidationError, Validator import click from ..util import spdx as spdxlib diff --git a/maubot/cli/commands/__init__.py b/maubot/cli/commands/__init__.py index 145646b..c535234 100644 --- a/maubot/cli/commands/__init__.py +++ b/maubot/cli/commands/__init__.py @@ -1 +1 @@ -from . import auth, build, init, login, logs, upload +from . import upload, build, login, init, logs, auth diff --git a/maubot/cli/commands/auth.py b/maubot/cli/commands/auth.py index 64b1dc7..d8cb3cb 100644 --- a/maubot/cli/commands/auth.py +++ b/maubot/cli/commands/auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,154 +13,75 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from urllib.parse import quote +from urllib.request import urlopen, Request +from urllib.error import HTTPError +import functools import json -import webbrowser from colorama import Fore -from yarl import URL -import aiohttp import click +from ..config import get_token from ..cliq import cliq history_count: int = 10 +enc = functools.partial(quote, safe="") + friendly_errors = { - "server_not_found": ( - "Registration target server not found.\n\n" - "To log in or register through maubot, you must add the server to the\n" - "homeservers section in the config. If you only want to log in,\n" - "leave the `secret` field empty." - ), - "registration_no_sso": ( - "The register operation is only for registering with a password.\n\n" - "To register with SSO, simply leave out the --register flag." - ), + "server_not_found": "Registration target server not found.\n\n" + "To log in or register through maubot, you must add the server to the\n" + "registration_secrets section in the config. If you only want to log in,\n" + "leave the `secret` field empty." } -async def list_servers(server: str, sess: aiohttp.ClientSession) -> None: - url = URL(server) / "_matrix/maubot/v1/client/auth/servers" - async with sess.get(url) as resp: - data = await resp.json() - print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") - for server in data.keys(): - print(f"* {Fore.CYAN}{server}{Fore.RESET}") - - @cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") -@cliq.option( - "-u", "--username", help="The username to log in with", required_unless=["list", "sso"] -) -@cliq.option( - "-p", - "--password", - help="The password to log in with", - inq_type="password", - required_unless=["list", "sso"], -) -@cliq.option( - "-s", - "--server", - help="The maubot instance to log in through", - default="", - required=False, - prompt=False, -) -@click.option( - "-r", "--register", help="Register instead of logging in", is_flag=True, default=False -) -@click.option( - "-c", - "--update-client", - help="Instead of returning the access token, create or update a client in maubot using it", - is_flag=True, - default=False, -) +@cliq.option("-u", "--username", help="The username to log in with", required_unless="list") +@cliq.option("-p", "--password", help="The password to log in with", inq_type="password", + required_unless="list") +@cliq.option("-s", "--server", help="The maubot instance to log in through", default="", + required=False, prompt=False) +@click.option("-r", "--register", help="Register instead of logging in", is_flag=True, + default=False) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) -@click.option( - "-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False -) -@click.option( - "-n", - "--device-name", - help="The initial e2ee device displayname (only for login)", - default="Maubot", - required=False, -) -@cliq.with_authenticated_http -async def auth( - homeserver: str, - username: str, - password: str, - server: str, - register: bool, - list: bool, - update_client: bool, - device_name: str, - sso: bool, - sess: aiohttp.ClientSession, -) -> None: - if list: - await list_servers(server, sess) +def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool + ) -> None: + server, token = get_token(server) + if not token: return + headers = {"Authorization": f"Bearer {token}"} + if list: + url = f"{server}/_matrix/maubot/v1/client/auth/servers" + with urlopen(Request(url, headers=headers)) as resp_data: + resp = json.load(resp_data) + print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") + for server in resp.keys(): + print(f"* {Fore.CYAN}{server}{Fore.RESET}") + return endpoint = "register" if register else "login" - url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint - if update_client: - url = url.update_query({"update_client": "true"}) - if sso: - url = url.update_query({"sso": "true"}) - req_data = {"device_name": device_name} - else: - req_data = {"username": username, "password": password, "device_name": device_name} - - async with sess.post(url, json=req_data) as resp: - if not 200 <= resp.status < 300: - await print_error(resp, is_register=register) - elif sso: - await wait_sso(resp, sess, server, homeserver) - else: - await print_response(resp, is_register=register) - - -async def wait_sso( - resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str -) -> None: - data = await resp.json() - sso_url, reg_id = data["sso_url"], data["id"] - print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}") - webbrowser.open(sso_url, autoraise=True) - print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}") - wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait" - async with sess.post(wait_url, json={}) as resp: - await print_response(resp, is_register=False) - - -async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> None: - if resp.status == 200: - data = await resp.json() - action = "registered" if is_register else "logged in as" - print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") - print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}") - print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}") - elif resp.status in (201, 202): - data = await resp.json() - action = "created" if resp.status == 201 else "updated" - print( - f"{Fore.GREEN}Successfully {action} client for " - f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " - f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}" - ) - else: - await print_error(resp, is_register) - - -async def print_error(resp: aiohttp.ClientResponse, is_register: bool) -> None: + headers["Content-Type"] = "application/json" + url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}" + req = Request(url, headers=headers, + data=json.dumps({ + "username": username, + "password": password, + }).encode("utf-8")) try: - err_data = await resp.json() - error = friendly_errors.get(err_data["errcode"], err_data["error"]) - except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): - error = await resp.text() - action = "register" if is_register else "log in" - print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") + with urlopen(req) as resp_data: + resp = json.load(resp_data) + action = "registered" if register else "logged in as" + print(f"{Fore.GREEN}Successfully {action} " + f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.") + print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}") + print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}") + except HTTPError as e: + try: + err_data = json.load(e) + error = friendly_errors.get(err_data["errcode"], err_data["error"]) + except (json.JSONDecodeError, KeyError): + error = str(e) + action = "register" if register else "log in" + print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") diff --git a/maubot/cli/commands/build.py b/maubot/cli/commands/build.py index 39eca53..a7b2fe1 100644 --- a/maubot/cli/commands/build.py +++ b/maubot/cli/commands/build.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,27 +13,22 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import IO +from typing import Optional, Union, IO from io import BytesIO -import asyncio +import zipfile import glob import os -import zipfile -from aiohttp import ClientSession -from colorama import Fore -from questionary import prompt from ruamel.yaml import YAML, YAMLError +from colorama import Fore +from PyInquirer import prompt import click from mautrix.types import SerializerError from ...loader import PluginMeta -from ..base import app -from ..cliq import cliq from ..cliq.validators import PathValidator +from ..base import app from ..config import get_token from .upload import upload_file @@ -46,7 +41,7 @@ def zipdir(zip, dir): zip.write(os.path.join(root, file)) -def read_meta(path: str) -> PluginMeta | None: +def read_meta(path: str) -> Optional[PluginMeta]: try: with open(os.path.join(path, "maubot.yaml")) as meta_file: try: @@ -67,7 +62,7 @@ def read_meta(path: str) -> PluginMeta | None: return meta -def read_output_path(output: str, meta: PluginMeta) -> str | None: +def read_output_path(output: str, meta: PluginMeta) -> Optional[str]: directory = os.getcwd() filename = f"{meta.id}-v{meta.version}.mbp" if not output: @@ -75,15 +70,18 @@ def read_output_path(output: str, meta: PluginMeta) -> str | None: elif os.path.isdir(output): output = os.path.join(output, filename) elif os.path.exists(output): - q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}] - override = prompt(q)["override"] + override = prompt({ + "type": "confirm", + "name": "override", + "message": f"{output} exists, override?" + })["override"] if not override: return None os.remove(output) return os.path.abspath(output) -def write_plugin(meta: PluginMeta, output: str | IO) -> None: +def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None: with zipfile.ZipFile(output, "w") as zip: meta_dump = BytesIO() yaml.dump(meta.serialize(), meta_dump) @@ -93,47 +91,34 @@ def write_plugin(meta: PluginMeta, output: str | IO) -> None: if os.path.isfile(f"{module}.py"): zip.write(f"{module}.py") elif module is not None and os.path.isdir(module): - if os.path.isfile(f"{module}/__init__.py"): - zipdir(zip, module) - else: - print( - Fore.YELLOW - + f"Module {module} is missing __init__.py, skipping" - + Fore.RESET - ) + zipdir(zip, module) else: print(Fore.YELLOW + f"Module {module} not found, skipping" + Fore.RESET) + for pattern in meta.extra_files: for file in glob.iglob(pattern): zip.write(file) -@cliq.with_authenticated_http -async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None: +def upload_plugin(output: Union[str, IO], server: str) -> None: server, token = get_token(server) if not token: return if isinstance(output, str): with open(output, "rb") as file: - await upload_file(sess, file, server) + upload_file(file, server, token) else: - await upload_file(sess, output, server) + upload_file(output, server, token) -@app.command( - short_help="Build a maubot plugin", - help=( - "Build a maubot plugin. First parameter is the path to root of the plugin " - "to build. You can also use --output to specify output file." - ), -) +@app.command(short_help="Build a maubot plugin", + help="Build a maubot plugin. First parameter is the path to root of the plugin " + "to build. You can also use --output to specify output file.") @click.argument("path", default=os.getcwd()) -@click.option( - "-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type -) -@click.option( - "-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False -) +@click.option("-o", "--output", help="Path to output built plugin to", + type=PathValidator.click_type) +@click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True, + default=False) @click.option("-s", "--server", help="Server to upload built plugin to") def build(path: str, output: str, upload: bool, server: str) -> None: meta = read_meta(path) @@ -152,4 +137,4 @@ def build(path: str, output: str, upload: bool, server: str) -> None: else: output.seek(0) if upload: - asyncio.run(upload_plugin(output, server=server)) + upload_plugin(output, server) diff --git a/maubot/cli/commands/init.py b/maubot/cli/commands/init.py index d24def9..7372a2d 100644 --- a/maubot/cli/commands/init.py +++ b/maubot/cli/commands/init.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,11 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from pkg_resources import resource_string import os -from jinja2 import Template from packaging.version import Version -from pkg_resources import resource_string +from jinja2 import Template from .. import cliq from ..cliq import SPDXValidator, VersionValidator @@ -40,55 +40,26 @@ def load_templates(): @cliq.command(help="Initialize a new maubot plugin") -@cliq.option( - "-n", - "--name", - help="The name of the project", - required=True, - default=os.path.basename(os.getcwd()), -) -@cliq.option( - "-i", - "--id", - message="ID", - required=True, - help="The maubot plugin ID (Java package name format)", -) -@cliq.option( - "-v", - "--version", - help="Initial version for project (PEP-440 format)", - default="0.1.0", - validator=VersionValidator, - required=True, -) -@cliq.option( - "-l", - "--license", - validator=SPDXValidator, - default="AGPL-3.0-or-later", - help="The license for the project (SPDX identifier)", - required=False, -) -@cliq.option( - "-c", - "--config", - message="Should the plugin include a config?", - help="Include a config in the plugin stub", - default=False, - is_flag=True, -) +@cliq.option("-n", "--name", help="The name of the project", required=True, + default=os.path.basename(os.getcwd())) +@cliq.option("-i", "--id", message="ID", required=True, + help="The maubot plugin ID (Java package name format)") +@cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)", + default="0.1.0", validator=VersionValidator, required=True) +@cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later", + help="The license for the project (SPDX identifier)", required=False) +@cliq.option("-c", "--config", message="Should the plugin include a config?", + help="Include a config in the plugin stub", default=False, is_flag=True) def init(name: str, id: str, version: Version, license: str, config: bool) -> None: load_templates() main_class = name[0].upper() + name[1:] - meta = meta_template.render( - id=id, version=str(version), license=license, config=config, main_class=main_class - ) + meta = meta_template.render(id=id, version=str(version), license=license, config=config, + main_class=main_class) with open("maubot.yaml", "w") as file: file.write(meta) if license: with open("LICENSE", "w") as file: - file.write(spdx.get(license)["licenseText"]) + file.write(spdx.get(license)["text"]) if not os.path.isdir(name): os.mkdir(name) mod = mod_template.render(config=config, name=main_class) diff --git a/maubot/cli/commands/login.py b/maubot/cli/commands/login.py index 8aac0f5..fdf71b3 100644 --- a/maubot/cli/commands/login.py +++ b/maubot/cli/commands/login.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,55 +13,32 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from urllib.request import urlopen +from urllib.error import HTTPError import json import os from colorama import Fore -from yarl import URL -import aiohttp +from ..config import save_config, config from ..cliq import cliq -from ..config import config, save_config @cliq.command(help="Log in to a Maubot instance") -@cliq.option( - "-u", - "--username", - help="The username of your account", - default=os.environ.get("USER", None), - required=True, -) -@cliq.option( - "-p", "--password", help="The password to your account", inq_type="password", required=True -) -@cliq.option( - "-s", - "--server", - help="The server to log in to", - default="http://localhost:29316", - required=True, -) -@cliq.option( - "-a", - "--alias", - help="Alias to reference the server without typing the full URL", - default="", - required=False, -) -@cliq.with_http -async def login( - server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession -) -> None: +@cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True) +@cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True) +@cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True) +@cliq.option("-a", "--alias", help="Alias to reference the server without typing the full URL", default="", required=False) +def login(server, username, password, alias) -> None: data = { "username": username, "password": password, } - url = URL(server) / "_matrix/maubot/v1/auth/login" - async with sess.post(url, json=data) as resp: - if resp.status == 200: - data = await resp.json() - config["servers"][server] = data["token"] + try: + with urlopen(f"{server}/_matrix/maubot/v1/auth/login", + data=json.dumps(data).encode("utf-8")) as resp_data: + resp = json.load(resp_data) + config["servers"][server] = resp["token"] if not config["default_server"]: print(Fore.CYAN, "Setting", server, "as the default server") config["default_server"] = server @@ -69,9 +46,9 @@ async def login( config["aliases"][alias] = server save_config() print(Fore.GREEN + "Logged in successfully") - else: - try: - err = (await resp.json())["error"] - except (json.JSONDecodeError, KeyError): - err = await resp.text() - print(Fore.RED + err + Fore.RESET) + except HTTPError as e: + try: + err = json.load(e) + except json.JSONDecodeError: + err = {} + print(Fore.RED + err.get("error", str(e)) + Fore.RESET) diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py index e0ed07d..8d0a578 100644 --- a/maubot/cli/commands/logs.py +++ b/maubot/cli/commands/logs.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,14 +16,14 @@ from datetime import datetime import asyncio -from aiohttp import ClientSession, WSMessage, WSMsgType from colorama import Fore +from aiohttp import WSMsgType, WSMessage, ClientSession import click from mautrix.types import Obj -from ..base import app from ..config import get_token +from ..base import app history_count: int = 10 @@ -38,13 +38,19 @@ def logs(server: str, tail: int) -> None: global history_count history_count = tail loop = asyncio.get_event_loop() - loop.run_until_complete(view_logs(server, token)) + future = asyncio.ensure_future(view_logs(server, token), loop=loop) + try: + loop.run_until_complete(future) + except KeyboardInterrupt: + future.cancel() + loop.run_until_complete(future) + loop.close() def parsedate(entry: Obj) -> None: i = entry.time.index("+") i = entry.time.index(":", i) - entry.time = entry.time[:i] + entry.time[i + 1 :] + entry.time = entry.time[:i] + entry.time[i + 1:] entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z") @@ -60,16 +66,13 @@ levelcolors = { def print_entry(entry: dict) -> None: entry = Obj(**entry) parsedate(entry) - print( - "{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format( - date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), - level=entry.levelname, - levelcolor=levelcolors.get(entry.levelname, ""), - resetcolor=Fore.RESET, - logger=entry.name, - message=entry.msg, - ) - ) + print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}" + .format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), + level=entry.levelname, + levelcolor=levelcolors.get(entry.levelname, ""), + resetcolor=Fore.RESET, + logger=entry.name, + message=entry.msg)) if entry.exc_info: print(entry.exc_info) diff --git a/maubot/cli/commands/upload.py b/maubot/cli/commands/upload.py index 3c2cf1e..cb5b4b5 100644 --- a/maubot/cli/commands/upload.py +++ b/maubot/cli/commands/upload.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,46 +13,45 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from urllib.request import urlopen, Request +from urllib.error import HTTPError from typing import IO import json from colorama import Fore -from yarl import URL -import aiohttp import click -from ..cliq import cliq +from ..base import app +from ..config import get_default_server, get_token class UploadError(Exception): pass -@cliq.command(help="Upload a maubot plugin") +@app.command(help="Upload a maubot plugin") @click.argument("path") @click.option("-s", "--server", help="The maubot instance to upload the plugin to") -@cliq.with_authenticated_http -async def upload(path: str, server: str, sess: aiohttp.ClientSession) -> None: +def upload(path: str, server: str) -> None: + server, token = get_token(server) + if not token: + return with open(path, "rb") as file: - await upload_file(sess, file, server) + upload_file(file, server, token) -async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> None: - url = (URL(server) / "_matrix/maubot/v1/plugins/upload").with_query({"allow_override": "true"}) - headers = {"Content-Type": "application/zip"} - async with sess.post(url, data=file, headers=headers) as resp: - if resp.status in (200, 201): - data = await resp.json() - print( - f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} " - f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}" - ) - else: - try: - err = await resp.json() - if "stacktrace" in err: - print(err["stacktrace"]) - err = err["error"] - except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): - err = await resp.text() - print(f"{Fore.RED}Failed to upload plugin: {err}{Fore.RESET}") +def upload_file(file: IO, server: str, token: str) -> None: + req = Request(f"{server}/_matrix/maubot/v1/plugins/upload?allow_override=true", data=file, + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/zip"}) + try: + with urlopen(req) as resp_data: + resp = json.load(resp_data) + print(f"{Fore.GREEN}Plugin {Fore.CYAN}{resp['id']} v{resp['version']}{Fore.GREEN} " + f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}") + except HTTPError as e: + try: + err = json.load(e) + except json.JSONDecodeError: + err = {} + print(err.get("stacktrace", "")) + print(Fore.RED + "Failed to upload plugin: " + err.get("error", str(e)) + Fore.RESET) diff --git a/maubot/cli/config.py b/maubot/cli/config.py index 5fdc4ea..550c326 100644 --- a/maubot/cli/config.py +++ b/maubot/cli/config.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,15 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import Any +from typing import Tuple, Optional, Dict, Any import json import os from colorama import Fore -config: dict[str, Any] = { +config: Dict[str, Any] = { "servers": {}, "aliases": {}, "default_server": None, @@ -29,19 +27,18 @@ config: dict[str, Any] = { configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config")) -def get_default_server() -> tuple[str | None, str | None]: +def get_default_server() -> Tuple[Optional[str], Optional[str]]: try: - server: str < None = config["default_server"] + server: Optional[str] = config["default_server"] except KeyError: server = None if server is None: print(f"{Fore.RED}Default server not configured.{Fore.RESET}") - print(f"Perhaps you forgot to {Fore.CYAN}mbc login{Fore.RESET}?") return None, None return server, _get_token(server) -def get_token(server: str) -> tuple[str | None, str | None]: +def get_token(server: str) -> Tuple[Optional[str], Optional[str]]: if not server: return get_default_server() if server in config["aliases"]: @@ -49,14 +46,14 @@ def get_token(server: str) -> tuple[str | None, str | None]: return server, _get_token(server) -def _resolve_alias(alias: str) -> str | None: +def _resolve_alias(alias: str) -> Optional[str]: try: return config["aliases"][alias] except KeyError: return None -def _get_token(server: str) -> str | None: +def _get_token(server: str) -> Optional[str]: try: return config["servers"][server] except KeyError: diff --git a/maubot/cli/res/spdx.json.zip b/maubot/cli/res/spdx.json.zip index 98de1b0..4cd4701 100644 Binary files a/maubot/cli/res/spdx.json.zip and b/maubot/cli/res/spdx.json.zip differ diff --git a/maubot/cli/util/spdx.py b/maubot/cli/util/spdx.py index 69f58b7..aca303d 100644 --- a/maubot/cli/util/spdx.py +++ b/maubot/cli/util/spdx.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,14 +13,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -import json +from typing import Dict, Optional import zipfile - import pkg_resources +import json -spdx_list: dict[str, dict[str, str]] | None = None +spdx_list: Optional[Dict[str, Dict[str, str]]] = None def load() -> None: @@ -33,13 +31,13 @@ def load() -> None: spdx_list = json.load(file) -def get(id: str) -> dict[str, str]: +def get(id: str) -> Dict[str, str]: if not spdx_list: load() - return spdx_list[id] + return spdx_list[id.lower()] def valid(id: str) -> bool: if not spdx_list: load() - return id in spdx_list + return id.lower() in spdx_list diff --git a/maubot/client.py b/maubot/client.py index b0fde73..9e3d1a7 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,141 +13,87 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, cast -from collections import defaultdict +from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING +from os import path import asyncio import logging from aiohttp import ClientSession +from yarl import URL +from mautrix.errors import MatrixInvalidToken, MatrixRequestError +from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, + StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, + PresenceState, StateFilter) from mautrix.client import InternalEventType -from mautrix.errors import MatrixInvalidToken -from mautrix.types import ( - ContentURI, - DeviceID, - EventFilter, - EventType, - Filter, - FilterID, - Membership, - PresenceState, - RoomEventFilter, - RoomFilter, - StateEvent, - StateFilter, - StrippedStateEvent, - SyncToken, - UserID, -) -from mautrix.util import background_task -from mautrix.util.async_getter_lock import async_getter_lock -from mautrix.util.logging import TraceLogger +from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore -from .db import Client as DBClient +from .lib.store_proxy import SyncStoreProxy +from .db import DBClient from .matrix import MaubotMatrixClient try: - from mautrix.crypto import OlmMachine, PgCryptoStore + from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore, + PickleCryptoStore) - crypto_import_error = None -except ImportError as e: - OlmMachine = PgCryptoStore = None - crypto_import_error = e + + class SQLStateStore(BaseSQLStateStore, CryptoStateStore): + pass +except ImportError: + OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None + SQLStateStore = BaseSQLStateStore + +try: + from mautrix.util.async_db import Database as AsyncDatabase + from mautrix.crypto import PgCryptoStore +except ImportError: + AsyncDatabase = None + PgCryptoStore = None if TYPE_CHECKING: - from .__main__ import Maubot from .instance import PluginInstance + from .config import Config + +log = logging.getLogger("maubot.client") -class Client(DBClient): - maubot: "Maubot" = None - cache: dict[UserID, Client] = {} - _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) - log: TraceLogger = logging.getLogger("maubot.client") - +class Client: + log: logging.Logger = None + loop: asyncio.AbstractEventLoop = None + cache: Dict[UserID, 'Client'] = {} http_client: ClientSession = None + global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() + crypto_pickle_dir: str = None + crypto_db: 'AsyncDatabase' = None - references: set[PluginInstance] + references: Set['PluginInstance'] + db_instance: DBClient client: MaubotMatrixClient - crypto: OlmMachine | None - crypto_store: PgCryptoStore | None + crypto: Optional['OlmMachine'] + crypto_store: Optional['CryptoStore'] started: bool - sync_ok: bool - remote_displayname: str | None - remote_avatar_url: ContentURI | None + remote_displayname: Optional[str] + remote_avatar_url: Optional[ContentURI] - def __init__( - self, - id: UserID, - homeserver: str, - access_token: str, - device_id: DeviceID, - enabled: bool = False, - next_batch: SyncToken = "", - filter_id: FilterID = "", - sync: bool = True, - autojoin: bool = True, - online: bool = True, - displayname: str = "disable", - avatar_url: str = "disable", - ) -> None: - super().__init__( - id=id, - homeserver=homeserver, - access_token=access_token, - device_id=device_id, - enabled=bool(enabled), - next_batch=next_batch, - filter_id=filter_id, - sync=bool(sync), - autojoin=bool(autojoin), - online=bool(online), - displayname=displayname, - avatar_url=avatar_url, - ) - self._postinited = False - - def __hash__(self) -> int: - return hash(self.id) - - @classmethod - def init_cls(cls, maubot: "Maubot") -> None: - cls.maubot = maubot - - def _make_client( - self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None - ) -> MaubotMatrixClient: - return MaubotMatrixClient( - mxid=self.id, - base_url=homeserver or self.homeserver, - token=token or self.access_token, - client_session=self.http_client, - log=self.log, - crypto_log=self.log.getChild("crypto"), - loop=self.maubot.loop, - device_id=device_id or self.device_id, - sync_store=self, - state_store=self.maubot.state_store, - ) - - def postinit(self) -> None: - if self._postinited: - raise RuntimeError("postinit() called twice") - self._postinited = True + def __init__(self, db_instance: DBClient) -> None: + self.db_instance = db_instance self.cache[self.id] = self - self.log = self.log.getChild(self.id) - self.http_client = ClientSession(loop=self.maubot.loop) + self.log = log.getChild(self.id) self.references = set() self.started = False self.sync_ok = True self.remote_displayname = None self.remote_avatar_url = None - self.client = self._make_client() - if self.enable_crypto: - self._prepare_crypto() + self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, + token=self.access_token, client_session=self.http_client, + log=self.log, loop=self.loop, device_id=self.device_id, + sync_store=SyncStoreProxy(self.db_instance), + state_store=self.global_state_store) + if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir): + self.crypto_store = self._make_crypto_store() + self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) + self.client.crypto = self.crypto else: self.crypto_store = None self.crypto = None @@ -160,56 +106,21 @@ class Client(DBClient): self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) - def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]: - async def handler(data: dict[str, Any]) -> None: + def _make_crypto_store(self) -> 'CryptoStore': + if self.crypto_db: + return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) + elif self.crypto_pickle_dir: + return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto", + path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle")) + raise ValueError("Crypto database not configured") + + def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: + async def handler(data: Dict[str, Any]) -> None: self.sync_ok = ok return handler - @property - def enable_crypto(self) -> bool: - if not self.device_id: - return False - elif not OlmMachine: - global crypto_import_error - self.log.warning( - "Client has device ID, but encryption dependencies not installed", - exc_info=crypto_import_error, - ) - # Clear the stack trace after it's logged once to avoid spamming logs - crypto_import_error = None - return False - elif not self.maubot.crypto_db: - self.log.warning("Client has device ID, but crypto database is not prepared") - return False - return True - - def _prepare_crypto(self) -> None: - self.crypto_store = PgCryptoStore( - account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db - ) - self.crypto = OlmMachine( - self.client, - self.crypto_store, - self.maubot.state_store, - log=self.client.crypto_log, - ) - self.client.crypto = self.crypto - - def _remove_crypto_event_handlers(self) -> None: - if not self.crypto: - return - handlers = [ - (InternalEventType.DEVICE_OTK_COUNT, self.crypto.handle_otk_count), - (InternalEventType.DEVICE_LISTS, self.crypto.handle_device_lists), - (EventType.TO_DEVICE_ENCRYPTED, self.crypto.handle_to_device_event), - (EventType.ROOM_KEY_REQUEST, self.crypto.handle_room_key_request), - (EventType.ROOM_MEMBER, self.crypto.handle_member_event), - ] - for event_type, func in handlers: - self.client.remove_event_handler(event_type, func) - - async def start(self, try_n: int | None = 0) -> None: + async def start(self, try_n: Optional[int] = 0) -> None: try: if try_n > 0: await asyncio.sleep(try_n * 10) @@ -217,21 +128,7 @@ class Client(DBClient): except Exception: self.log.exception("Failed to start") - async def _start_crypto(self) -> None: - self.log.debug("Enabling end-to-end encryption support") - await self.crypto_store.open() - crypto_device_id = await self.crypto_store.get_device_id() - if crypto_device_id and crypto_device_id != self.device_id: - self.log.warning( - "Mismatching device ID in crypto store and main database, resetting encryption" - ) - await self.crypto_store.delete() - crypto_device_id = None - await self.crypto.load() - if not crypto_device_id: - await self.crypto_store.put_device_id(self.device_id) - - async def _start(self, try_n: int | None = 0) -> None: + async def _start(self, try_n: Optional[int] = 0) -> None: if not self.enabled: self.log.debug("Not starting disabled client") return @@ -239,60 +136,53 @@ class Client(DBClient): self.log.warning("Ignoring start() call to started client") return try: - await self.client.versions() - whoami = await self.client.whoami() + user_id = await self.client.whoami() except MatrixInvalidToken as e: self.log.error(f"Invalid token: {e}. Disabling client") - self.enabled = False - await self.update() + self.db_instance.enabled = False return except Exception as e: if try_n >= 8: self.log.exception("Failed to get /account/whoami, disabling client") - self.enabled = False - await self.update() + self.db_instance.enabled = False else: - self.log.warning( - f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}" - ) - background_task.create(self.start(try_n + 1)) + self.log.warning(f"Failed to get /account/whoami, " + f"retrying in {(try_n + 1) * 10}s: {e}") + _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) return - if whoami.user_id != self.id: - self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") - self.enabled = False - await self.update() - return - elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: - self.log.error( - f"Device ID mismatch: expected {self.device_id}, but got {whoami.device_id}" - ) - self.enabled = False - await self.update() + if user_id != self.id: + self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") + self.db_instance.enabled = False return if not self.filter_id: - self.filter_id = await self.client.create_filter( - Filter( - room=RoomFilter( - timeline=RoomEventFilter( - limit=50, - lazy_load_members=True, - ), - state=StateFilter( - lazy_load_members=True, - ), + self.db_instance.edit(filter_id=await self.client.create_filter(Filter( + room=RoomFilter( + timeline=RoomEventFilter( + limit=50, + lazy_load_members=True, ), - presence=EventFilter( - not_types=[EventType.PRESENCE], - ), - ) - ) - await self.update() + state=StateFilter( + lazy_load_members=True, + ) + ), + presence=EventFilter( + not_types=[EventType.PRESENCE], + ), + ))) if self.displayname != "disable": await self.client.set_displayname(self.displayname) if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) if self.crypto: - await self._start_crypto() + self.log.debug("Enabling end-to-end encryption support") + await self.crypto_store.open() + crypto_device_id = await self.crypto_store.get_device_id() + if crypto_device_id and crypto_device_id != self.device_id: + self.log.warning("Mismatching device ID in crypto store and main database. " + "Encryption may not work.") + await self.crypto.load() + if not crypto_device_id: + await self.crypto_store.put_device_id(self.device_id) self.start_sync() await self._update_remote_profile() self.started = True @@ -320,22 +210,24 @@ class Client(DBClient): if self.crypto: await self.crypto_store.close() - async def clear_cache(self) -> None: + def clear_cache(self) -> None: self.stop_sync() - self.filter_id = FilterID("") - self.next_batch = SyncToken("") - await self.update() + self.db_instance.edit(filter_id="", next_batch="") self.start_sync() + def delete(self) -> None: + try: + del self.cache[self.id] + except KeyError: + pass + self.db_instance.delete() + def to_dict(self) -> dict: return { "id": self.id, "homeserver": self.homeserver, "access_token": self.access_token, "device_id": self.device_id, - "fingerprint": ( - self.crypto.account.fingerprint if self.crypto and self.crypto.account else None - ), "enabled": self.enabled, "started": self.started, "sync": self.sync, @@ -349,45 +241,32 @@ class Client(DBClient): "instances": [instance.to_dict() for instance in self.references], } + @classmethod + def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: + try: + return cls.cache[user_id] + except KeyError: + db_instance = db_instance or DBClient.get(user_id) + if not db_instance: + return None + return Client(db_instance) + + @classmethod + def all(cls) -> Iterable['Client']: + return (cls.get(user.id, user) for user in DBClient.all()) + async def _handle_tombstone(self, evt: StateEvent) -> None: - if evt.state_key != "": - return if not evt.content.replacement_room: self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring") return - is_joined = await self.client.state_store.is_joined( - evt.content.replacement_room, - self.client.mxid, - ) - if is_joined: - self.log.debug( - f"Ignoring tombstone from {evt.room_id} to {evt.content.replacement_room} " - f"sent by {evt.sender}: already joined to replacement room" - ) - return - self.log.debug( - f"Following tombstone from {evt.room_id} to {evt.content.replacement_room} " - f"sent by {evt.sender}" - ) _, server = self.client.parse_user_id(evt.sender) - room_id = await self.client.join_room(evt.content.replacement_room, servers=[server]) - power_levels = await self.client.get_state_event(room_id, EventType.ROOM_POWER_LEVELS) - if power_levels.get_user_level(evt.sender) < power_levels.invite: - self.log.warning( - f"{evt.room_id} was tombstoned into {room_id} by {evt.sender}," - " but the sender doesn't have invite power levels, leaving..." - ) - await self.client.leave_room( - room_id, - f"Followed tombstone from {evt.room_id} by {evt.sender}," - " but sender doesn't have sufficient power level for invites", - ) + await self.client.join_room(evt.content.replacement_room, servers=[server]) async def _handle_invite(self, evt: StrippedStateEvent) -> None: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room(evt.room_id) - async def update_started(self, started: bool | None) -> None: + async def update_started(self, started: bool) -> None: if started is None or started == self.started: return if started: @@ -395,162 +274,154 @@ class Client(DBClient): else: await self.stop() - async def update_enabled(self, enabled: bool | None, save: bool = True) -> None: - if enabled is None or enabled == self.enabled: - return - self.enabled = enabled - if save: - await self.update() - - async def update_displayname(self, displayname: str | None, save: bool = True) -> None: + async def update_displayname(self, displayname: str) -> None: if displayname is None or displayname == self.displayname: return - self.displayname = displayname + self.db_instance.displayname = displayname if self.displayname != "disable": await self.client.set_displayname(self.displayname) else: await self._update_remote_profile() - if save: - await self.update() - async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None: + async def update_avatar_url(self, avatar_url: ContentURI) -> None: if avatar_url is None or avatar_url == self.avatar_url: return - self.avatar_url = avatar_url + self.db_instance.avatar_url = avatar_url if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) else: await self._update_remote_profile() - if save: - await self.update() - async def update_sync(self, sync: bool | None, save: bool = True) -> None: - if sync is None or self.sync == sync: - return - self.sync = sync - if self.started: - if sync: - self.start_sync() - else: - self.stop_sync() - if save: - await self.update() - - async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None: - if autojoin is None or autojoin == self.autojoin: - return - if autojoin: - self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - else: - self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - self.autojoin = autojoin - if save: - await self.update() - - async def update_online(self, online: bool | None, save: bool = True) -> None: - if online is None or online == self.online: - return - self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE - self.online = online - if save: - await self.update() - - async def update_access_details( - self, - access_token: str | None, - homeserver: str | None, - device_id: str | None = None, - ) -> None: + async def update_access_details(self, access_token: str, homeserver: str) -> None: if not access_token and not homeserver: return - if device_id is None: - device_id = self.device_id - elif not device_id: - device_id = None - if ( - access_token == self.access_token - and homeserver == self.homeserver - and device_id == self.device_id - ): + elif access_token == self.access_token and homeserver == self.homeserver: return - new_client = self._make_client(homeserver, access_token, device_id) - whoami = await new_client.whoami() - if whoami.user_id != self.id: - raise ValueError(f"MXID mismatch: {whoami.user_id}") - elif whoami.device_id and device_id and whoami.device_id != device_id: - raise ValueError(f"Device ID mismatch: {whoami.device_id}") - new_client.sync_store = self + new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, + token=access_token or self.access_token, loop=self.loop, + client_session=self.http_client, device_id=self.device_id, + log=self.log, state_store=self.global_state_store) + mxid = await new_client.whoami() + if mxid != self.id: + raise ValueError(f"MXID mismatch: {mxid}") + new_client.sync_store = SyncStoreProxy(self.db_instance) self.stop_sync() - - # TODO this event handler transfer is pretty hacky - self._remove_crypto_event_handlers() - self.client.crypto = None - new_client.event_handlers = self.client.event_handlers - new_client.global_event_handlers = self.client.global_event_handlers - self.client = new_client - self.homeserver = homeserver - self.access_token = access_token - self.device_id = device_id - if self.enable_crypto: - self._prepare_crypto() - await self._start_crypto() - else: - self.crypto_store = None - self.crypto = None + self.db_instance.homeserver = homeserver + self.db_instance.access_token = access_token self.start_sync() async def _update_remote_profile(self) -> None: profile = await self.client.get_profile(self.id) self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url - async def delete(self) -> None: - try: - del self.cache[self.id] - except KeyError: - pass - await super().delete() + # region Properties - @classmethod - @async_getter_lock - async def get( - cls, - user_id: UserID, - *, - homeserver: str | None = None, - access_token: str | None = None, - device_id: DeviceID | None = None, - ) -> Client | None: - try: - return cls.cache[user_id] - except KeyError: - pass + @property + def id(self) -> UserID: + return self.db_instance.id - user = cast(cls, await super().get(user_id)) - if user is not None: - user.postinit() - return user + @property + def homeserver(self) -> str: + return self.db_instance.homeserver - if homeserver and access_token: - user = cls( - user_id, - homeserver=homeserver, - access_token=access_token, - device_id=device_id or "", - ) - await user.insert() - user.postinit() - return user + @property + def access_token(self) -> str: + return self.db_instance.access_token - return None + @property + def device_id(self) -> str: + return self.db_instance.device_id - @classmethod - async def all(cls) -> AsyncGenerator[Client, None]: - users = await super().all() - user: cls - for user in users: - try: - yield cls.cache[user.id] - except KeyError: - user.postinit() - yield user + @property + def enabled(self) -> bool: + return self.db_instance.enabled + + @enabled.setter + def enabled(self, value: bool) -> None: + self.db_instance.enabled = value + + @property + def next_batch(self) -> SyncToken: + return self.db_instance.next_batch + + @property + def filter_id(self) -> FilterID: + return self.db_instance.filter_id + + @property + def sync(self) -> bool: + return self.db_instance.sync + + @sync.setter + def sync(self, value: bool) -> None: + if value == self.db_instance.sync: + return + self.db_instance.sync = value + if self.started: + if value: + self.start_sync() + else: + self.stop_sync() + + @property + def autojoin(self) -> bool: + return self.db_instance.autojoin + + @autojoin.setter + def autojoin(self, value: bool) -> None: + if value == self.db_instance.autojoin: + return + if value: + self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + else: + self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + self.db_instance.autojoin = value + + @property + def online(self) -> bool: + return self.db_instance.online + + @online.setter + def online(self, value: bool) -> None: + self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE + self.db_instance.online = value + + @property + def displayname(self) -> str: + return self.db_instance.displayname + + @property + def avatar_url(self) -> ContentURI: + return self.db_instance.avatar_url + + # endregion + + +def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]: + Client.http_client = ClientSession(loop=loop) + Client.loop = loop + + if OlmMachine: + db_type = config["crypto_database.type"] + if db_type == "default": + db_url = config["database"] + parsed_url = URL(db_url) + if parsed_url.scheme == "sqlite": + Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] + elif parsed_url.scheme == "postgres": + if not PgCryptoStore: + log.warning("Default database is postgres, but asyncpg is not installed. " + "Encryption will not work.") + else: + Client.crypto_db = AsyncDatabase(url=db_url, + upgrade_table=PgCryptoStore.upgrade_table) + elif db_type == "pickle": + Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] + elif db_type == "postgres" and PgCryptoStore: + Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"], + upgrade_table=PgCryptoStore.upgrade_table) + else: + raise ValueError("Unsupported crypto database type") + + return Client.all() diff --git a/maubot/config.py b/maubot/config.py index b8e42de..34466cc 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,10 +14,9 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import random -import re import string - import bcrypt +import re from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper @@ -32,50 +31,36 @@ class Config(BaseFileConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: base = helper.base copy = helper.copy - - if "database" in self and self["database"].startswith("sqlite:///"): - helper.base["database"] = self["database"].replace("sqlite:///", "sqlite:") - else: - copy("database") - copy("database_opts") - if isinstance(self["crypto_database"], dict): - if self["crypto_database.type"] == "postgres": - base["crypto_database"] = self["crypto_database.postgres_uri"] - else: - copy("crypto_database") + copy("database") + copy("crypto_database.type") + copy("crypto_database.postgres_uri") + copy("crypto_database.pickle_dir") copy("plugin_directories.upload") copy("plugin_directories.load") copy("plugin_directories.trash") - if "plugin_directories.db" in self: - base["plugin_databases.sqlite"] = self["plugin_directories.db"] - else: - copy("plugin_databases.sqlite") - copy("plugin_databases.postgres") - copy("plugin_databases.postgres_opts") + copy("plugin_directories.db") copy("server.hostname") copy("server.port") copy("server.public_url") copy("server.listen") + copy("server.base_path") copy("server.ui_base_path") copy("server.plugin_base_path") copy("server.override_resource_path") + copy("server.appservice_base_path") shared_secret = self["server.unshared_secret"] if shared_secret is None or shared_secret == "generate": base["server.unshared_secret"] = self._new_token() else: base["server.unshared_secret"] = shared_secret - if "registration_secrets" in self: - base["homeservers"] = self["registration_secrets"] - else: - copy("homeservers") + copy("registration_secrets") copy("admins") for username, password in base["admins"].items(): if password and not bcrypt_regex.match(password): if password == "password": password = self._new_token() - base["admins"][username] = bcrypt.hashpw( - password.encode("utf-8"), bcrypt.gensalt() - ).decode("utf-8") + base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"), + bcrypt.gensalt()).decode("utf-8") copy("api_features.login") copy("api_features.plugin") copy("api_features.plugin_upload") diff --git a/maubot/db.py b/maubot/db.py new file mode 100644 index 0000000..3817882 --- /dev/null +++ b/maubot/db.py @@ -0,0 +1,101 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2019 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import Iterable, Optional +import logging +import sys + +from sqlalchemy import Column, String, Boolean, ForeignKey, Text +from sqlalchemy.engine.base import Engine +import sqlalchemy as sql + +from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI +from mautrix.util.db import Base +from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile + +from .config import Config + + +class DBPlugin(Base): + __tablename__ = "plugin" + + id: str = Column(String(255), primary_key=True) + type: str = Column(String(255), nullable=False) + enabled: bool = Column(Boolean, nullable=False, default=False) + primary_user: UserID = Column(String(255), + ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), + nullable=False) + config: str = Column(Text, nullable=False, default='') + + @classmethod + def all(cls) -> Iterable['DBPlugin']: + return cls._select_all() + + @classmethod + def get(cls, id: str) -> Optional['DBPlugin']: + return cls._select_one_or_none(cls.c.id == id) + + +class DBClient(Base): + __tablename__ = "client" + + id: UserID = Column(String(255), primary_key=True) + homeserver: str = Column(String(255), nullable=False) + access_token: str = Column(Text, nullable=False) + device_id: DeviceID = Column(String(255), nullable=True) + enabled: bool = Column(Boolean, nullable=False, default=False) + + next_batch: SyncToken = Column(String(255), nullable=False, default="") + filter_id: FilterID = Column(String(255), nullable=False, default="") + + sync: bool = Column(Boolean, nullable=False, default=True) + autojoin: bool = Column(Boolean, nullable=False, default=True) + online: bool = Column(Boolean, nullable=False, default=True) + + displayname: str = Column(String(255), nullable=False, default="") + avatar_url: ContentURI = Column(String(255), nullable=False, default="") + + @classmethod + def all(cls) -> Iterable['DBClient']: + return cls._select_all() + + @classmethod + def get(cls, id: str) -> Optional['DBClient']: + return cls._select_one_or_none(cls.c.id == id) + + +def init(config: Config) -> Engine: + db = sql.create_engine(config["database"]) + Base.metadata.bind = db + + for table in (DBPlugin, DBClient, RoomState, UserProfile): + table.bind(db) + + if not db.has_table("alembic_version"): + log = logging.getLogger("maubot.db") + + if db.has_table("client") and db.has_table("plugin"): + log.warning("alembic_version table not found, but client and plugin tables found. " + "Assuming pre-Alembic database and inserting version.") + db.execute("CREATE TABLE IF NOT EXISTS alembic_version (" + " version_num VARCHAR(32) PRIMARY KEY" + ");") + db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');") + else: + log.critical("alembic_version table not found. " + "Did you forget to `alembic upgrade head`?") + sys.exit(10) + + return db diff --git a/maubot/db/__init__.py b/maubot/db/__init__.py deleted file mode 100644 index 68833ce..0000000 --- a/maubot/db/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from mautrix.util.async_db import Database - -from .client import Client -from .instance import DatabaseEngine, Instance -from .upgrade import upgrade_table - - -def init(db: Database) -> None: - for table in (Client, Instance): - table.db = db - - -__all__ = ["upgrade_table", "init", "Client", "Instance", "DatabaseEngine"] diff --git a/maubot/db/client.py b/maubot/db/client.py deleted file mode 100644 index 52f3a20..0000000 --- a/maubot/db/client.py +++ /dev/null @@ -1,114 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar - -from asyncpg import Record -from attr import dataclass - -from mautrix.client import SyncStore -from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID -from mautrix.util.async_db import Database - -fake_db = Database.create("") if TYPE_CHECKING else None - - -@dataclass -class Client(SyncStore): - db: ClassVar[Database] = fake_db - - id: UserID - homeserver: str - access_token: str - device_id: DeviceID - enabled: bool - - next_batch: SyncToken - filter_id: FilterID - - sync: bool - autojoin: bool - online: bool - - displayname: str - avatar_url: ContentURI - - @classmethod - def _from_row(cls, row: Record | None) -> Client | None: - if row is None: - return None - return cls(**row) - - _columns = ( - "id, homeserver, access_token, device_id, enabled, next_batch, filter_id, " - "sync, autojoin, online, displayname, avatar_url" - ) - - @property - def _values(self): - return ( - self.id, - self.homeserver, - self.access_token, - self.device_id, - self.enabled, - self.next_batch, - self.filter_id, - self.sync, - self.autojoin, - self.online, - self.displayname, - self.avatar_url, - ) - - @classmethod - async def all(cls) -> list[Client]: - rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client") - return [cls._from_row(row) for row in rows] - - @classmethod - async def get(cls, id: str) -> Client | None: - q = f"SELECT {cls._columns} FROM client WHERE id=$1" - return cls._from_row(await cls.db.fetchrow(q, id)) - - async def insert(self) -> None: - q = """ - INSERT INTO client ( - id, homeserver, access_token, device_id, enabled, next_batch, filter_id, - sync, autojoin, online, displayname, avatar_url - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) - """ - await self.db.execute(q, *self._values) - - async def put_next_batch(self, next_batch: SyncToken) -> None: - await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id) - self.next_batch = next_batch - - async def get_next_batch(self) -> SyncToken: - return self.next_batch - - async def update(self) -> None: - q = """ - UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5, - next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10, - displayname=$11, avatar_url=$12 - WHERE id=$1 - """ - await self.db.execute(q, *self._values) - - async def delete(self) -> None: - await self.db.execute("DELETE FROM client WHERE id=$1", self.id) diff --git a/maubot/db/instance.py b/maubot/db/instance.py deleted file mode 100644 index 5bb3f6a..0000000 --- a/maubot/db/instance.py +++ /dev/null @@ -1,101 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING, ClassVar -from enum import Enum - -from asyncpg import Record -from attr import dataclass - -from mautrix.types import UserID -from mautrix.util.async_db import Database - -fake_db = Database.create("") if TYPE_CHECKING else None - - -class DatabaseEngine(Enum): - SQLITE = "sqlite" - POSTGRES = "postgres" - - -@dataclass -class Instance: - db: ClassVar[Database] = fake_db - - id: str - type: str - enabled: bool - primary_user: UserID - config_str: str - database_engine: DatabaseEngine | None - - @property - def database_engine_str(self) -> str | None: - return self.database_engine.value if self.database_engine else None - - @classmethod - def _from_row(cls, row: Record | None) -> Instance | None: - if row is None: - return None - data = {**row} - db_engine = data.pop("database_engine", None) - return cls(**data, database_engine=DatabaseEngine(db_engine) if db_engine else None) - - _columns = "id, type, enabled, primary_user, config, database_engine" - - @classmethod - async def all(cls) -> list[Instance]: - q = f"SELECT {cls._columns} FROM instance" - rows = await cls.db.fetch(q) - return [cls._from_row(row) for row in rows] - - @classmethod - async def get(cls, id: str) -> Instance | None: - q = f"SELECT {cls._columns} FROM instance WHERE id=$1" - return cls._from_row(await cls.db.fetchrow(q, id)) - - async def update_id(self, new_id: str) -> None: - await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id) - self.id = new_id - - @property - def _values(self): - return ( - self.id, - self.type, - self.enabled, - self.primary_user, - self.config_str, - self.database_engine_str, - ) - - async def insert(self) -> None: - q = ( - "INSERT INTO instance (id, type, enabled, primary_user, config, database_engine) " - "VALUES ($1, $2, $3, $4, $5, $6)" - ) - await self.db.execute(q, *self._values) - - async def update(self) -> None: - q = """ - UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5, database_engine=$6 - WHERE id=$1 - """ - await self.db.execute(q, *self._values) - - async def delete(self) -> None: - await self.db.execute("DELETE FROM instance WHERE id=$1", self.id) diff --git a/maubot/db/upgrade/__init__.py b/maubot/db/upgrade/__init__.py deleted file mode 100644 index ed96422..0000000 --- a/maubot/db/upgrade/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from mautrix.util.async_db import UpgradeTable - -upgrade_table = UpgradeTable() - -from . import v01_initial_revision, v02_instance_database_engine diff --git a/maubot/db/upgrade/v01_initial_revision.py b/maubot/db/upgrade/v01_initial_revision.py deleted file mode 100644 index 2da8aff..0000000 --- a/maubot/db/upgrade/v01_initial_revision.py +++ /dev/null @@ -1,136 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from __future__ import annotations - -from mautrix.util.async_db import Connection, Scheme - -from . import upgrade_table - -legacy_version_query = "SELECT version_num FROM alembic_version" -last_legacy_version = "90aa88820eab" - - -@upgrade_table.register(description="Initial asyncpg revision") -async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: - if await conn.table_exists("alembic_version"): - await migrate_legacy_to_v1(conn, scheme) - else: - return await create_v1_tables(conn) - - -async def create_v1_tables(conn: Connection) -> None: - await conn.execute( - """CREATE TABLE client ( - id TEXT PRIMARY KEY, - homeserver TEXT NOT NULL, - access_token TEXT NOT NULL, - device_id TEXT NOT NULL, - enabled BOOLEAN NOT NULL, - - next_batch TEXT NOT NULL, - filter_id TEXT NOT NULL, - - sync BOOLEAN NOT NULL, - autojoin BOOLEAN NOT NULL, - online BOOLEAN NOT NULL, - - displayname TEXT NOT NULL, - avatar_url TEXT NOT NULL - )""" - ) - await conn.execute( - """CREATE TABLE instance ( - id TEXT PRIMARY KEY, - type TEXT NOT NULL, - enabled BOOLEAN NOT NULL, - primary_user TEXT NOT NULL, - config TEXT NOT NULL, - FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE - )""" - ) - - -async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None: - legacy_version = await conn.fetchval(legacy_version_query) - if legacy_version != last_legacy_version: - raise RuntimeError( - "Legacy database is not on last version. " - "Please upgrade the old database with alembic or drop it completely first." - ) - await conn.execute("ALTER TABLE plugin RENAME TO instance") - await update_state_store(conn, scheme) - if scheme != Scheme.SQLITE: - await varchar_to_text(conn) - await conn.execute("DROP TABLE alembic_version") - - -async def update_state_store(conn: Connection, scheme: Scheme) -> None: - # The Matrix state store already has more or less the correct schema, so set the version - await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)") - await conn.execute("INSERT INTO mx_version (version) VALUES (2)") - if scheme != Scheme.SQLITE: - # Remove old uppercase membership type and recreate it as lowercase - await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT") - await conn.execute("DROP TYPE IF EXISTS membership") - await conn.execute( - "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" - ) - await conn.execute( - "ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership " - "USING LOWER(membership)::membership" - ) - else: - # Recreate table to remove CHECK constraint and lowercase everything - await conn.execute( - """CREATE TABLE new_mx_user_profile ( - room_id TEXT, - user_id TEXT, - membership TEXT NOT NULL - CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')), - displayname TEXT, - avatar_url TEXT, - PRIMARY KEY (room_id, user_id) - )""" - ) - await conn.execute( - """ - INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url) - SELECT room_id, user_id, LOWER(membership), displayname, avatar_url - FROM mx_user_profile - """ - ) - await conn.execute("DROP TABLE mx_user_profile") - await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile") - - -async def varchar_to_text(conn: Connection) -> None: - columns_to_adjust = { - "client": ( - "id", - "homeserver", - "device_id", - "next_batch", - "filter_id", - "displayname", - "avatar_url", - ), - "instance": ("id", "type", "primary_user"), - "mx_room_state": ("room_id",), - "mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"), - } - for table, columns in columns_to_adjust.items(): - for column in columns: - await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT') diff --git a/maubot/db/upgrade/v02_instance_database_engine.py b/maubot/db/upgrade/v02_instance_database_engine.py deleted file mode 100644 index 7d2d7e7..0000000 --- a/maubot/db/upgrade/v02_instance_database_engine.py +++ /dev/null @@ -1,25 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from __future__ import annotations - -from mautrix.util.async_db import Connection - -from . import upgrade_table - - -@upgrade_table.register(description="Store instance database engine") -async def upgrade_v2(conn: Connection) -> None: - await conn.execute("ALTER TABLE instance ADD COLUMN database_engine TEXT") diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml deleted file mode 100644 index 0a6c8ac..0000000 --- a/maubot/example-config.yaml +++ /dev/null @@ -1,131 +0,0 @@ -# The full URI to the database. SQLite and Postgres are fully supported. -# Format examples: -# SQLite: sqlite:filename.db -# Postgres: postgresql://username:password@hostname/dbname -database: sqlite:maubot.db - -# Separate database URL for the crypto database. "default" means use the same database as above. -crypto_database: default - -# Additional arguments for asyncpg.create_pool() or sqlite3.connect() -# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool -# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect -# For sqlite, min_size is used as the connection thread pool size and max_size is ignored. -database_opts: - min_size: 1 - max_size: 10 - -# Configuration for storing plugin .mbp files -plugin_directories: - # The directory where uploaded new plugins should be stored. - upload: ./plugins - # The directories from which plugins should be loaded. - # Duplicate plugin IDs will be moved to the trash. - load: - - ./plugins - # The directory where old plugin versions and conflicting plugins should be moved. - # Set to "delete" to delete files immediately. - trash: ./trash - -# Configuration for storing plugin databases -plugin_databases: - # The directory where SQLite plugin databases should be stored. - sqlite: ./plugins - # The connection URL for plugin databases. If null, all plugins will get SQLite databases. - # If set, plugins using the new asyncpg interface will get a Postgres connection instead. - # Plugins using the legacy SQLAlchemy interface will always get a SQLite connection. - # - # To use the same connection pool as the default database, set to "default" - # (the default database above must be postgres to do this). - # - # When enabled, maubot will create separate Postgres schemas in the database for each plugin. - # To view schemas in psql, use `\dn`. To view enter and interact with a specific schema, - # use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal - # SQL queries/psql commands. - postgres: null - # Maximum number of connections per plugin instance. - postgres_max_conns_per_plugin: 3 - # Overrides for the default database_opts when using a non-"default" postgres connection string. - postgres_opts: {} - -server: - # The IP and port to listen to. - hostname: 0.0.0.0 - port: 29316 - # Public base URL where the server is visible. - public_url: https://example.com - # The base path for the UI. - ui_base_path: /_matrix/maubot - # The base path for plugin endpoints. The instance ID will be appended directly. - plugin_base_path: /_matrix/maubot/plugin/ - # Override path from where to load UI resources. - # Set to false to using pkg_resources to find the path. - override_resource_path: false - # The shared secret to sign API access tokens. - # Set to "generate" to generate and save a new token at startup. - unshared_secret: generate - -# Known homeservers. This is required for the `mbc auth` command and also allows -# more convenient access from the management UI. This is not required to create -# clients in the management UI, since you can also just type the homeserver URL -# into the box there. -homeservers: - matrix.org: - # Client-server API URL - url: https://matrix-client.matrix.org - # registration_shared_secret from synapse config - # You can leave this empty if you don't have access to the homeserver. - # When this is empty, `mbc auth --register` won't work, but `mbc auth` (login) will. - secret: null - -# List of administrator users. Each key is a username and the value is the password. -# Plaintext passwords will be bcrypted on startup. Set empty password to prevent normal login. -# Root is a special user that can't have a password and will always exist. -admins: - root: "" - -# API feature switches. -api_features: - login: true - plugin: true - plugin_upload: true - instance: true - instance_database: true - client: true - client_proxy: true - client_auth: true - dev_open: true - log: true - -# Python logging configuration. -# -# See section 16.7.2 of the Python documentation for more info: -# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema -logging: - version: 1 - formatters: - colored: - (): maubot.lib.color_log.ColorFormatter - format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" - normal: - format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" - handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: normal - filename: ./maubot.log - maxBytes: 10485760 - backupCount: 10 - console: - class: logging.StreamHandler - formatter: colored - loggers: - maubot: - level: DEBUG - mau: - level: DEBUG - aiohttp: - level: INFO - root: - level: DEBUG - handlers: [file, console] diff --git a/maubot/handlers/__init__.py b/maubot/handlers/__init__.py index e8567a2..1d9da7e 100644 --- a/maubot/handlers/__init__.py +++ b/maubot/handlers/__init__.py @@ -1 +1 @@ -from . import command, event, web +from . import event, command, web diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index 27e6547..81fccad 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,46 +13,29 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Iterable, - List, - NewType, - Optional, - Pattern, - Sequence, - Set, - Tuple, - Union, -) +from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, + Dict, Tuple, Set, Iterable) from abc import ABC, abstractmethod import asyncio import functools import inspect import re -from mautrix.types import EventType, MessageType +from mautrix.types import MessageType, EventType from ..matrix import MaubotMessageEvent from . import event PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]] -AliasesType = Union[ - List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool] -] -CommandHandlerFunc = NewType( - "CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]] -) -CommandHandlerDecorator = NewType( - "CommandHandlerDecorator", - Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"], -) -PassiveCommandHandlerDecorator = NewType( - "PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc] -) +AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool], + Callable[[Any, str], bool]] +CommandHandlerFunc = NewType("CommandHandlerFunc", + Callable[[MaubotMessageEvent, Any], Awaitable[Any]]) +CommandHandlerDecorator = NewType("CommandHandlerDecorator", + Callable[[Union['CommandHandler', CommandHandlerFunc]], + 'CommandHandler']) +PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", + Callable[[CommandHandlerFunc], CommandHandlerFunc]) def _split_in_two(val: str, split_by: str) -> List[str]: @@ -72,7 +55,7 @@ class CommandHandler: self.__mb_must_consume_args__: bool = True self.__mb_arg_fallthrough__: bool = True self.__mb_event_handler__: bool = True - self.__mb_event_types__: set[EventType] = {EventType.ROOM_MESSAGE} + self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,) self.__bound_copies__: Dict[Any, CommandHandler] = {} self.__bound_instance__: Any = None @@ -84,27 +67,15 @@ class CommandHandler: return self.__bound_copies__[instance] except KeyError: new_ch = type(self)(self.__mb_func__) - keys = [ - "parent", - "subcommands", - "arguments", - "help", - "get_name", - "is_command_match", - "require_subcommand", - "must_consume_args", - "arg_fallthrough", - "event_handler", - "event_types", - "msgtypes", - ] + keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match", + "require_subcommand", "arg_fallthrough", "event_handler", "event_type", + "msgtypes"] for key in keys: key = f"__mb_{key}__" setattr(new_ch, key, getattr(self, key)) new_ch.__bound_instance__ = instance - new_ch.__mb_subcommands__ = [ - subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__ - ] + new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype) + for subcmd in self.__mb_subcommands__] self.__bound_copies__[instance] = new_ch return new_ch @@ -112,13 +83,8 @@ class CommandHandler: def __command_match_unset(self, val: str) -> bool: raise NotImplementedError("Hmm") - async def __call__( - self, - evt: MaubotMessageEvent, - *, - _existing_args: Dict[str, Any] = None, - remaining_val: str = None, - ) -> Any: + async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None, + remaining_val: str = None) -> Any: if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__: return if remaining_val is None: @@ -154,25 +120,21 @@ class CommandHandler: return await self.__mb_func__(self.__bound_instance__, evt, **call_args) return await self.__mb_func__(evt, **call_args) - async def __call_subcommand__( - self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str - ) -> Tuple[bool, Any]: + async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], + remaining_val: str) -> Tuple[bool, Any]: command, remaining_val = _split_in_two(remaining_val.strip(), " ") for subcommand in self.__mb_subcommands__: if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command): - return True, await subcommand( - evt, _existing_args=call_args, remaining_val=remaining_val - ) + return True, await subcommand(evt, _existing_args=call_args, + remaining_val=remaining_val) return False, None - async def __parse_args__( - self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str - ) -> Tuple[bool, str]: + async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], + remaining_val: str) -> Tuple[bool, str]: for arg in self.__mb_arguments__: try: - remaining_val, call_args[arg.name] = arg.match( - remaining_val.strip(), evt=evt, instance=self.__bound_instance__ - ) + remaining_val, call_args[arg.name] = arg.match(remaining_val.strip(), evt=evt, + instance=self.__bound_instance__) if arg.required and call_args[arg.name] is None: raise ValueError("Argument required") except ArgumentSyntaxError as e: @@ -193,9 +155,8 @@ class CommandHandler: @property def __mb_usage_args__(self) -> str: - arg_usage = " ".join( - f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__ - ) + arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]" + for arg in self.__mb_arguments__) if self.__mb_subcommands__ and self.__mb_arg_fallthrough__: arg_usage += " " + self.__mb_usage_subcommand__ return arg_usage @@ -211,19 +172,15 @@ class CommandHandler: @property def __mb_prefix__(self) -> str: if self.__mb_parent__: - return ( - f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} " - f"{self.__mb_name__}" - ) + return (f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} " + f"{self.__mb_name__}") return f"!{self.__mb_name__}" @property def __mb_usage_inline__(self) -> str: if not self.__mb_arg_fallthrough__: - return ( - f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" - f"* {self.__mb_name__} {self.__mb_usage_subcommand__}" - ) + return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" + f"* {self.__mb_name__} {self.__mb_usage_subcommand__}") return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}" @property @@ -235,10 +192,8 @@ class CommandHandler: if not self.__mb_arg_fallthrough__: if not self.__mb_arguments__: return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]" - return ( - f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" - f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}" - ) + return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" + f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}") return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" @property @@ -247,25 +202,14 @@ class CommandHandler: return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}" return self.__mb_usage_without_subcommands__ - def subcommand( - self, - name: PrefixType = None, - *, - help: str = None, - aliases: AliasesType = None, - required_subcommand: bool = True, - arg_fallthrough: bool = True, - ) -> CommandHandlerDecorator: + def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, + required_subcommand: bool = True, arg_fallthrough: bool = True, + ) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) - new( - name, - help=help, - aliases=aliases, - require_subcommand=required_subcommand, - arg_fallthrough=arg_fallthrough, - )(func) + new(name, help=help, aliases=aliases, require_subcommand=required_subcommand, + arg_fallthrough=arg_fallthrough)(func) func.__mb_parent__ = self func.__mb_event_handler__ = False self.__mb_subcommands__.append(func) @@ -274,17 +218,10 @@ class CommandHandler: return decorator -def new( - name: PrefixType = None, - *, - help: str = None, - aliases: AliasesType = None, - event_type: EventType = EventType.ROOM_MESSAGE, - msgtypes: Iterable[MessageType] = None, - require_subcommand: bool = True, - arg_fallthrough: bool = True, - must_consume_args: bool = True, -) -> CommandHandlerDecorator: +def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, + event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None, + require_subcommand: bool = True, arg_fallthrough: bool = True, + must_consume_args: bool = True) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) @@ -298,16 +235,15 @@ def new( else: func.__mb_get_name__ = lambda self: name else: - func.__mb_get_name__ = lambda self: func.__mb_func__.__name__.replace("_", "-") + func.__mb_get_name__ = lambda self: func.__mb_func__.__name__ if callable(aliases): if len(inspect.getfullargspec(aliases).args) == 1: func.__mb_is_command_match__ = lambda self, val: aliases(val) else: func.__mb_is_command_match__ = aliases elif isinstance(aliases, (list, set, tuple)): - func.__mb_is_command_match__ = lambda self, val: ( - val == func.__mb_get_name__(self) or val in aliases - ) + func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_get_name__(self) + or val in aliases) else: func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self) # Decorators are executed last to first, so we reverse the argument list. @@ -315,7 +251,7 @@ def new( func.__mb_require_subcommand__ = require_subcommand func.__mb_arg_fallthrough__ = arg_fallthrough func.__mb_must_consume_args__ = must_consume_args - func.__mb_event_types__ = {event_type} + func.__mb_event_type__ = event_type if msgtypes: func.__mb_msgtypes__ = msgtypes return func @@ -331,9 +267,8 @@ class ArgumentSyntaxError(ValueError): class Argument(ABC): - def __init__( - self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False - ) -> None: + def __init__(self, name: str, label: str = None, *, required: bool = False, + pass_raw: bool = False) -> None: self.name = name self.label = label or name self.required = required @@ -351,15 +286,8 @@ class Argument(ABC): class RegexArgument(Argument): - def __init__( - self, - name: str, - label: str = None, - *, - required: bool = False, - pass_raw: bool = False, - matches: str = None, - ) -> None: + def __init__(self, name: str, label: str = None, *, required: bool = False, + pass_raw: bool = False, matches: str = None) -> None: super().__init__(name, label, required=required, pass_raw=pass_raw) matches = f"^{matches}" if self.pass_raw else f"^{matches}$" self.regex = re.compile(matches) @@ -370,23 +298,14 @@ class RegexArgument(Argument): val = re.split(r"\s", val, 1)[0] match = self.regex.match(val) if match: - return ( - orig_val[: match.start()] + orig_val[match.end() :], - match.groups() or val[match.start() : match.end()], - ) + return (orig_val[:match.start()] + orig_val[match.end():], + match.groups() or val[match.start():match.end()]) return orig_val, None class CustomArgument(Argument): - def __init__( - self, - name: str, - label: str = None, - *, - required: bool = False, - pass_raw: bool = False, - matcher: Callable[[str], Any], - ) -> None: + def __init__(self, name: str, label: str = None, *, required: bool = False, + pass_raw: bool = False, matcher: Callable[[str], Any]) -> None: super().__init__(name, label, required=required, pass_raw=pass_raw) self.matcher = matcher @@ -397,7 +316,7 @@ class CustomArgument(Argument): val = re.split(r"\s", val, 1)[0] res = self.matcher(val) if res is not None: - return orig_val[len(val) :], res + return orig_val[len(val):], res return orig_val, None @@ -406,18 +325,12 @@ class SimpleArgument(Argument): if self.pass_raw: return "", val res = re.split(r"\s", val, 1)[0] - return val[len(res) :], res + return val[len(res):], res -def argument( - name: str, - label: str = None, - *, - required: bool = True, - matches: Optional[str] = None, - parser: Optional[Callable[[str], Any]] = None, - pass_raw: bool = False, -) -> CommandHandlerDecorator: +def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None, + parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False + ) -> CommandHandlerDecorator: if matches: return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw) elif parser: @@ -426,17 +339,11 @@ def argument( return SimpleArgument(name, label, required=required, pass_raw=pass_raw) -def passive( - regex: Union[str, Pattern], - *, - msgtypes: Sequence[MessageType] = (MessageType.TEXT,), - field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, - event_type: EventType = EventType.ROOM_MESSAGE, - multiple: bool = False, - case_insensitive: bool = False, - multiline: bool = False, - dot_all: bool = False, -) -> PassiveCommandHandlerDecorator: +def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,), + field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, + event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False, + case_insensitive: bool = False, multiline: bool = False, dot_all: bool = False + ) -> PassiveCommandHandlerDecorator: if not isinstance(regex, Pattern): flags = re.RegexFlag.UNICODE if case_insensitive: @@ -465,14 +372,12 @@ def passive( return data = field(evt) if multiple: - val = [ - (data[match.pos : match.endpos], *match.groups()) - for match in regex.finditer(data) - ] + val = [(data[match.pos:match.endpos], *match.groups()) + for match in regex.finditer(data)] else: match = regex.search(data) if match: - val = (data[match.pos : match.endpos], *match.groups()) + val = (data[match.pos:match.endpos], *match.groups()) else: val = None if val: diff --git a/maubot/handlers/event.py b/maubot/handlers/event.py index 9be89b1..be02706 100644 --- a/maubot/handlers/event.py +++ b/maubot/handlers/event.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,26 +13,22 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations +from typing import Callable, Union, NewType -from typing import Callable, NewType - -from mautrix.client import EventHandler, InternalEventType from mautrix.types import EventType +from mautrix.client import EventHandler, InternalEventType EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) -def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler: +def on(var: Union[EventType, InternalEventType, EventHandler] + ) -> Union[EventHandlerDecorator, EventHandler]: def decorator(func: EventHandler) -> EventHandler: func.__mb_event_handler__ = True if isinstance(var, (EventType, InternalEventType)): - if hasattr(func, "__mb_event_types__"): - func.__mb_event_types__.add(var) - else: - func.__mb_event_types__ = {var} + func.__mb_event_type__ = var else: - func.__mb_event_types__ = {EventType.ALL} + func.__mb_event_type__ = EventType.ALL return func diff --git a/maubot/handlers/web.py b/maubot/handlers/web.py index f170124..cf53d68 100644 --- a/maubot/handlers/web.py +++ b/maubot/handlers/web.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,9 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Any, Awaitable, Callable +from typing import Callable, Any, Awaitable -from aiohttp import hdrs, web +from aiohttp import web, hdrs WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]] WebHandlerDecorator = Callable[[WebHandler], WebHandler] diff --git a/maubot/instance.py b/maubot/instance.py index 8427e3c..8d1dea2 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,92 +13,58 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, AsyncGenerator, cast -from collections import defaultdict -import asyncio -import inspect -import io -import logging +from typing import Dict, List, Optional, Iterable, TYPE_CHECKING +from asyncio import AbstractEventLoop import os.path +import logging +import io -from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap +from ruamel.yaml import YAML +import sqlalchemy as sql -from mautrix.types import UserID -from mautrix.util import background_task -from mautrix.util.async_db import Database, Scheme, UpgradeTable -from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.config import BaseProxyConfig, RecursiveDict -from mautrix.util.logging import TraceLogger +from mautrix.types import UserID +from .db import DBPlugin +from .config import Config from .client import Client -from .db import DatabaseEngine, Instance as DBInstance -from .lib.optionalalchemy import Engine, MetaData, create_engine -from .lib.plugin_db import ProxyPostgresDatabase -from .loader import DatabaseType, PluginLoader, ZippedPluginLoader +from .loader import PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: - from .__main__ import Maubot - from .server import PluginWebApp + from .server import MaubotServer, PluginWebApp -log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance")) -db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db")) +log = logging.getLogger("maubot.instance") yaml = YAML() yaml.indent(4) yaml.width = 200 -class PluginInstance(DBInstance): - maubot: "Maubot" = None - cache: dict[str, PluginInstance] = {} - plugin_directories: list[str] = [] - _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) +class PluginInstance: + webserver: 'MaubotServer' = None + mb_config: Config = None + loop: AbstractEventLoop = None + cache: Dict[str, 'PluginInstance'] = {} + plugin_directories: List[str] = [] log: logging.Logger - loader: PluginLoader | None - client: Client | None - plugin: Plugin | None - config: BaseProxyConfig | None - base_cfg: RecursiveDict[CommentedMap] | None - base_cfg_str: str | None - inst_db: sql.engine.Engine | Database | None - inst_db_tables: dict | None - inst_webapp: PluginWebApp | None - inst_webapp_url: str | None + loader: PluginLoader + client: Client + plugin: Plugin + config: BaseProxyConfig + base_cfg: Optional[RecursiveDict[CommentedMap]] + base_cfg_str: Optional[str] + inst_db: sql.engine.Engine + inst_db_tables: Dict[str, sql.Table] + inst_webapp: Optional['PluginWebApp'] + inst_webapp_url: Optional[str] started: bool - def __init__( - self, - id: str, - type: str, - enabled: bool, - primary_user: UserID, - config: str = "", - database_engine: DatabaseEngine | None = None, - ) -> None: - super().__init__( - id=id, - type=type, - enabled=bool(enabled), - primary_user=primary_user, - config_str=config, - database_engine=database_engine, - ) - - def __hash__(self) -> int: - return hash(self.id) - - @classmethod - def init_cls(cls, maubot: "Maubot") -> None: - cls.maubot = maubot - - def postinit(self) -> None: + def __init__(self, db_instance: DBPlugin): + self.db_instance = db_instance self.log = log.getChild(self.id) - self.cache[self.id] = self self.config = None self.started = False self.loader = None @@ -110,6 +76,7 @@ class PluginInstance(DBInstance): self.inst_webapp_url = None self.base_cfg = None self.base_cfg_str = None + self.cache[self.id] = self def to_dict(self) -> dict: return { @@ -118,124 +85,35 @@ class PluginInstance(DBInstance): "enabled": self.enabled, "started": self.started, "primary_user": self.primary_user, - "config": self.config_str, + "config": self.db_instance.config, "base_config": self.base_cfg_str, - "database": ( - self.inst_db is not None and self.maubot.config["api_features.instance_database"] - ), - "database_interface": self.loader.meta.database_type_str if self.loader else "unknown", - "database_engine": self.database_engine_str, + "database": (self.inst_db is not None + and self.mb_config["api_features.instance_database"]), } - def _introspect_sqlalchemy(self) -> dict: - metadata = MetaData() - metadata.reflect(self.inst_db) - return { - table.name: { - "columns": { - column.name: { - "type": str(column.type), - "unique": column.unique or False, - "default": column.default, - "nullable": column.nullable, - "primary": column.primary_key, - } - for column in table.columns - }, - } - for table in metadata.tables.values() - } - - async def _introspect_sqlite(self) -> dict: - q = """ - SELECT - m.name AS table_name, - p.cid AS col_id, - p.name AS column_name, - p.type AS data_type, - p.pk AS is_primary, - p.dflt_value AS column_default, - p.[notnull] AS is_nullable - FROM sqlite_master m - LEFT JOIN pragma_table_info((m.name)) p - WHERE m.type = 'table' - ORDER BY table_name, col_id - """ - data = await self.inst_db.fetch(q) - tables = defaultdict(lambda: {"columns": {}}) - for column in data: - table_name = column["table_name"] - col_name = column["column_name"] - tables[table_name]["columns"][col_name] = { - "type": column["data_type"], - "nullable": bool(column["is_nullable"]), - "default": column["column_default"], - "primary": bool(column["is_primary"]), - # TODO uniqueness? - } - return tables - - async def _introspect_postgres(self) -> dict: - assert isinstance(self.inst_db, ProxyPostgresDatabase) - q = """ - SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default, - tc.constraint_type - FROM information_schema.columns col - LEFT JOIN information_schema.constraint_column_usage ccu - ON ccu.column_name=col.column_name - LEFT JOIN information_schema.table_constraints tc - ON col.table_name=tc.table_name - AND col.table_schema=tc.table_schema - AND ccu.constraint_name=tc.constraint_name - AND ccu.constraint_schema=tc.constraint_schema - AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE') - WHERE col.table_schema=$1 - """ - data = await self.inst_db.fetch(q, self.inst_db.schema_name) - tables = defaultdict(lambda: {"columns": {}}) - for column in data: - table_name = column["table_name"] - col_name = column["column_name"] - tables[table_name]["columns"].setdefault( - col_name, - { - "type": column["data_type"], - "nullable": column["is_nullable"], - "default": column["column_default"], - "primary": False, - "unique": False, - }, - ) - if column["constraint_type"] == "PRIMARY KEY": - tables[table_name]["columns"][col_name]["primary"] = True - elif column["constraint_type"] == "UNIQUE": - tables[table_name]["columns"][col_name]["unique"] = True - return tables - - async def get_db_tables(self) -> dict: - if self.inst_db_tables is None: - if isinstance(self.inst_db, Engine): - self.inst_db_tables = self._introspect_sqlalchemy() - elif self.inst_db.scheme == Scheme.SQLITE: - self.inst_db_tables = await self._introspect_sqlite() - else: - self.inst_db_tables = await self._introspect_postgres() + def get_db_tables(self) -> Dict[str, sql.Table]: + if not self.inst_db_tables: + metadata = sql.MetaData() + metadata.reflect(self.inst_db) + self.inst_db_tables = metadata.tables return self.inst_db_tables - async def load(self) -> bool: + def load(self) -> bool: if not self.loader: try: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") - await self.update_enabled(False) + self.db_instance.enabled = False return False if not self.client: - self.client = await Client.get(self.primary_user) + self.client = Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") - await self.update_enabled(False) + self.db_instance.enabled = False return False + if self.loader.meta.database: + self.enable_database() if self.loader.meta.webapp: self.enable_webapp() self.log.debug("Plugin instance dependencies loaded") @@ -244,18 +122,18 @@ class PluginInstance(DBInstance): return True def enable_webapp(self) -> None: - self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id) + self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id) def disable_webapp(self) -> None: - self.maubot.server.remove_instance_webapp(self.id) + self.webserver.remove_instance_webapp(self.id) self.inst_webapp = None self.inst_webapp_url = None - @property - def _sqlite_db_path(self) -> str: - return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db") + def enable_database(self) -> None: + db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) + self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") - async def delete(self) -> None: + def delete(self) -> None: if self.loader is not None: self.loader.references.remove(self) if self.client is not None: @@ -264,89 +142,22 @@ class PluginInstance(DBInstance): del self.cache[self.id] except KeyError: pass - await super().delete() + self.db_instance.delete() if self.inst_db: - await self.stop_database() - await self.delete_database() + self.inst_db.dispose() + ZippedPluginLoader.trash( + os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), + reason="deleted") if self.inst_webapp: self.disable_webapp() def load_config(self) -> CommentedMap: - return yaml.load(self.config_str) + return yaml.load(self.db_instance.config) def save_config(self, data: RecursiveDict[CommentedMap]) -> None: buf = io.StringIO() yaml.dump(data, buf) - val = buf.getvalue() - if val != self.config_str: - self.config_str = val - self.log.debug("Creating background task to save updated config") - background_task.create(self.update()) - - async def start_database( - self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True - ) -> None: - if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: - if self.database_engine is None: - await self.update_db_engine(DatabaseEngine.SQLITE) - elif self.database_engine == DatabaseEngine.POSTGRES: - raise RuntimeError( - "Instance database engine is marked as Postgres, but plugin uses legacy " - "database interface, which doesn't support postgres." - ) - self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}") - elif self.loader.meta.database_type == DatabaseType.ASYNCPG: - if self.database_engine is None: - if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db: - await self.update_db_engine(DatabaseEngine.SQLITE) - else: - await self.update_db_engine(DatabaseEngine.POSTGRES) - instance_db_log = db_log.getChild(self.id) - if self.database_engine == DatabaseEngine.POSTGRES: - if not self.maubot.plugin_postgres_db: - raise RuntimeError( - "Instance database engine is marked as Postgres, but this maubot isn't " - "configured to support Postgres for plugin databases" - ) - self.inst_db = ProxyPostgresDatabase( - pool=self.maubot.plugin_postgres_db, - instance_id=self.id, - max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"], - upgrade_table=upgrade_table, - log=instance_db_log, - ) - else: - self.inst_db = Database.create( - f"sqlite:{self._sqlite_db_path}", - upgrade_table=upgrade_table, - log=instance_db_log, - ) - if actually_start: - await self.inst_db.start() - else: - raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") - - async def stop_database(self) -> None: - if isinstance(self.inst_db, Database): - await self.inst_db.stop() - elif isinstance(self.inst_db, Engine): - self.inst_db.dispose() - else: - raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}") - - async def delete_database(self) -> None: - if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: - ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") - elif self.loader.meta.database_type == DatabaseType.ASYNCPG: - if self.inst_db is None: - await self.start_database(None, actually_start=False) - if isinstance(self.inst_db, ProxyPostgresDatabase): - await self.inst_db.delete() - else: - ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") - else: - raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") - self.inst_db = None + self.db_instance.config = buf.getvalue() async def start(self) -> None: if self.started: @@ -357,7 +168,7 @@ class PluginInstance(DBInstance): return if not self.client or not self.loader: self.log.warning("Missing plugin instance dependencies, attempting to load...") - if not await self.load(): + if not self.load(): return cls = await self.loader.load() if self.loader.meta.webapp and self.inst_webapp is None: @@ -366,13 +177,9 @@ class PluginInstance(DBInstance): elif not self.loader.meta.webapp and self.inst_webapp is not None: self.log.debug("Disabling webapp after plugin meta reload") self.disable_webapp() - if self.loader.meta.database: - try: - await self.start_database(cls.get_db_upgrade_table()) - except Exception: - self.log.exception("Failed to start instance database") - await self.update_enabled(False) - return + if self.loader.meta.database and self.inst_db is None: + self.log.debug("Enabling database after plugin meta reload") + self.enable_database() config_class = cls.get_config_class() if config_class: try: @@ -387,35 +194,23 @@ class PluginInstance(DBInstance): if self.base_cfg: base_cfg_func = self.base_cfg.clone else: - def base_cfg_func() -> None: return None - self.config = config_class(self.load_config, base_cfg_func, self.save_config) - self.plugin = cls( - client=self.client.client, - loop=self.maubot.loop, - http=self.client.http_client, - instance_id=self.id, - log=self.log, - config=self.config, - database=self.inst_db, - loader=self.loader, - webapp=self.inst_webapp, - webapp_url=self.inst_webapp_url, - ) + self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client, + instance_id=self.id, log=self.log, config=self.config, + database=self.inst_db, loader=self.loader, webapp=self.inst_webapp, + webapp_url=self.inst_webapp_url) try: await self.plugin.internal_start() except Exception: self.log.exception("Failed to start instance") - await self.update_enabled(False) + self.db_instance.enabled = False return self.started = True self.inst_db_tables = None - self.log.info( - f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " - f"with user {self.client.id}" - ) + self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " + f"with user {self.client.id}") async def stop(self) -> None: if not self.started: @@ -428,58 +223,63 @@ class PluginInstance(DBInstance): except Exception: self.log.exception("Failed to stop instance") self.plugin = None - if self.inst_db: - try: - await self.stop_database() - except Exception: - self.log.exception("Failed to stop instance database") self.inst_db_tables = None - async def update_id(self, new_id: str | None) -> None: - if new_id is not None and new_id.lower() != self.id: - await super().update_id(new_id.lower()) + @classmethod + def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None + ) -> Optional['PluginInstance']: + try: + return cls.cache[instance_id] + except KeyError: + db_instance = db_instance or DBPlugin.get(instance_id) + if not db_instance: + return None + return PluginInstance(db_instance) - async def update_config(self, config: str | None) -> None: - if config is None or self.config_str == config: + @classmethod + def all(cls) -> Iterable['PluginInstance']: + return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all()) + + def update_id(self, new_id: str) -> None: + if new_id is not None and new_id != self.id: + self.db_instance.id = new_id.lower() + + def update_config(self, config: str) -> None: + if not config or self.db_instance.config == config: return - self.config_str = config + self.db_instance.config = config if self.started and self.plugin is not None: - res = self.plugin.on_external_config_update() - if inspect.isawaitable(res): - await res - await self.update() + self.plugin.on_external_config_update() - async def update_primary_user(self, primary_user: UserID | None) -> bool: - if primary_user is None or primary_user == self.primary_user: + async def update_primary_user(self, primary_user: UserID) -> bool: + if not primary_user or primary_user == self.primary_user: return True - client = await Client.get(primary_user) + client = Client.get(primary_user) if not client: return False await self.stop() - self.primary_user = client.id + self.db_instance.primary_user = client.id if self.client: self.client.references.remove(self) self.client = client self.client.references.add(self) - await self.update() await self.start() self.log.debug(f"Primary user switched to {self.client.id}") return True - async def update_type(self, type: str | None) -> bool: - if type is None or type == self.type: + async def update_type(self, type: str) -> bool: + if not type or type == self.type: return True try: loader = PluginLoader.find(type) except KeyError: return False await self.stop() - self.type = loader.meta.id + self.db_instance.type = loader.meta.id if self.loader: self.loader.references.remove(self) self.loader = loader self.loader.references.add(self) - await self.update() await self.start() self.log.debug(f"Type switched to {self.loader.meta.id}") return True @@ -488,46 +288,38 @@ class PluginInstance(DBInstance): if started is not None and started != self.started: await (self.start() if started else self.stop()) - async def update_enabled(self, enabled: bool) -> None: + def update_enabled(self, enabled: bool) -> None: if enabled is not None and enabled != self.enabled: - self.enabled = enabled - await self.update() + self.db_instance.enabled = enabled - async def update_db_engine(self, db_engine: DatabaseEngine | None) -> None: - if db_engine is not None and db_engine != self.database_engine: - self.database_engine = db_engine - await self.update() + # region Properties - @classmethod - @async_getter_lock - async def get( - cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None - ) -> PluginInstance | None: - try: - return cls.cache[instance_id] - except KeyError: - pass + @property + def id(self) -> str: + return self.db_instance.id - instance = cast(cls, await super().get(instance_id)) - if instance is not None: - instance.postinit() - return instance + @id.setter + def id(self, value: str) -> None: + self.db_instance.id = value - if type and primary_user: - instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user) - await instance.insert() - instance.postinit() - return instance + @property + def type(self) -> str: + return self.db_instance.type - return None + @property + def enabled(self) -> bool: + return self.db_instance.enabled - @classmethod - async def all(cls) -> AsyncGenerator[PluginInstance, None]: - instances = await super().all() - instance: PluginInstance - for instance in instances: - try: - yield cls.cache[instance.id] - except KeyError: - instance.postinit() - yield instance + @property + def primary_user(self) -> UserID: + return self.db_instance.primary_user + + # endregion + + +def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop + ) -> Iterable[PluginInstance]: + PluginInstance.mb_config = config + PluginInstance.loop = loop + PluginInstance.webserver = webserver + return PluginInstance.all() diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py index 8c36ed5..284cf74 100644 --- a/maubot/lib/color_log.py +++ b/maubot/lib/color_log.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2020 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,13 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.util.logging.color import ( - MAU_COLOR, - MXID_COLOR, - PREFIX, - RESET, - ColorFormatter as BaseColorFormatter, -) +from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR, + MXID_COLOR, RESET) INST_COLOR = PREFIX + "35m" # magenta LOADER_COLOR = PREFIX + "36m" # blue @@ -28,22 +23,14 @@ LOADER_COLOR = PREFIX + "36m" # blue class ColorFormatter(BaseColorFormatter): def _color_name(self, module: str) -> str: client = "maubot.client" - if module.startswith(client + "."): - suffix = "" - if module.endswith(".crypto"): - suffix = f".{MAU_COLOR}crypto{RESET}" - module = module[: -len(".crypto")] - module = module[len(client) + 1 :] - return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}" + if module.startswith(client): + return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}" instance = "maubot.instance" - if module.startswith(instance + "."): + if module.startswith(instance): return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" - instance_db = "maubot.instance_db" - if module.startswith(instance_db + "."): - return f"{MAU_COLOR}{instance_db}{RESET}.{INST_COLOR}{module[len(instance_db) + 1:]}{RESET}" loader = "maubot.loader" - if module.startswith(loader + "."): + if module.startswith(loader): return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" - if module.startswith("maubot."): + if module.startswith("maubot"): return f"{MAU_COLOR}{module}{RESET}" return super()._color_name(module) diff --git a/maubot/lib/future_awaitable.py b/maubot/lib/future_awaitable.py deleted file mode 100644 index 388eae9..0000000 --- a/maubot/lib/future_awaitable.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Any, Awaitable, Callable, Generator - - -class FutureAwaitable: - def __init__(self, func: Callable[[], Awaitable[None]]) -> None: - self._func = func - - def __await__(self) -> Generator[Any, None, None]: - return self._func().__await__() diff --git a/maubot/lib/optionalalchemy.py b/maubot/lib/optionalalchemy.py deleted file mode 100644 index ba94271..0000000 --- a/maubot/lib/optionalalchemy.py +++ /dev/null @@ -1,19 +0,0 @@ -try: - from sqlalchemy import MetaData, asc, create_engine, desc - from sqlalchemy.engine import Engine - from sqlalchemy.exc import IntegrityError, OperationalError -except ImportError: - - class FakeError(Exception): - pass - - class FakeType: - def __init__(self, *args, **kwargs): - raise Exception("SQLAlchemy is not installed") - - def create_engine(*args, **kwargs): - raise Exception("SQLAlchemy is not installed") - - MetaData = Engine = FakeType - IntegrityError = OperationalError = FakeError - asc = desc = lambda a: a diff --git a/maubot/lib/plugin_db.py b/maubot/lib/plugin_db.py deleted file mode 100644 index 977a619..0000000 --- a/maubot/lib/plugin_db.py +++ /dev/null @@ -1,100 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from __future__ import annotations - -from contextlib import asynccontextmanager -import asyncio - -from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable -from mautrix.util.async_db.connection import LoggingConnection -from mautrix.util.logging import TraceLogger - -remove_double_quotes = str.maketrans({'"': "_"}) - - -class ProxyPostgresDatabase(Database): - scheme = Scheme.POSTGRES - _underlying_pool: PostgresDatabase - schema_name: str - _quoted_schema: str - _default_search_path: str - _conn_sema: asyncio.Semaphore - _max_conns: int - - def __init__( - self, - pool: PostgresDatabase, - instance_id: str, - max_conns: int, - upgrade_table: UpgradeTable | None, - log: TraceLogger | None = None, - ) -> None: - super().__init__(pool.url, upgrade_table=upgrade_table, log=log) - self._underlying_pool = pool - # Simple accidental SQL injection prevention. - # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway. - self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}" - self._quoted_schema = f'"{self.schema_name}"' - self._default_search_path = '"$user", public' - self._conn_sema = asyncio.BoundedSemaphore(max_conns) - self._max_conns = max_conns - - async def start(self) -> None: - async with self._underlying_pool.acquire() as conn: - self._default_search_path = await conn.fetchval("SHOW search_path") - self.log.trace(f"Found default search path: {self._default_search_path}") - await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}") - await super().start() - - async def stop(self) -> None: - for _ in range(self._max_conns): - try: - await asyncio.wait_for(self._conn_sema.acquire(), timeout=3) - except asyncio.TimeoutError: - self.log.warning( - "Failed to drain plugin database connection pool, " - "the plugin may be leaking database connections" - ) - break - - async def delete(self) -> None: - self.log.info(f"Deleting schema {self.schema_name} and all data in it") - try: - await self._underlying_pool.execute( - f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE" - ) - except Exception: - self.log.warning("Failed to delete schema", exc_info=True) - - @asynccontextmanager - async def acquire(self) -> LoggingConnection: - conn: LoggingConnection - async with self._conn_sema, self._underlying_pool.acquire() as conn: - await conn.execute(f"SET search_path = {self._quoted_schema}") - try: - yield conn - finally: - if not conn.wrapped.is_closed(): - try: - await conn.execute(f"SET search_path = {self._default_search_path}") - except Exception: - self.log.exception("Error resetting search_path after use") - await conn.wrapped.close() - else: - self.log.debug("Connection was closed after use, not resetting search_path") - - -__all__ = ["ProxyPostgresDatabase"] diff --git a/maubot/lib/state_store.py b/maubot/lib/store_proxy.py similarity index 61% rename from maubot/lib/state_store.py rename to maubot/lib/store_proxy.py index 81fb5fd..6e402aa 100644 --- a/maubot/lib/state_store.py +++ b/maubot/lib/store_proxy.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,15 +13,16 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore +from mautrix.client import SyncStore +from mautrix.types import SyncToken -try: - from mautrix.crypto import StateStore as CryptoStateStore - class PgStateStore(BasePgStateStore, CryptoStateStore): - pass +class SyncStoreProxy(SyncStore): + def __init__(self, db_instance) -> None: + self.db_instance = db_instance -except ImportError as e: - PgStateStore = BasePgStateStore + async def put_next_batch(self, next_batch: SyncToken) -> None: + self.db_instance.edit(next_batch=next_batch) -__all__ = ["PgStateStore"] + async def get_next_batch(self) -> SyncToken: + return self.db_instance.next_batch diff --git a/maubot/lib/zipimport.py b/maubot/lib/zipimport.py index e7b77db..f9a0ca7 100644 --- a/maubot/lib/zipimport.py +++ b/maubot/lib/zipimport.py @@ -18,28 +18,26 @@ used by the builtin import mechanism for sys.path items that are paths to Zip archives. """ -from importlib import _bootstrap # for _verbose_message from importlib import _bootstrap_external +from importlib import _bootstrap # for _verbose_message +import _imp # for check_hash_based_pycs +import _io # for open import marshal # for loads import sys # for modules import time # for mktime -import _imp # for check_hash_based_pycs -import _io # for open - -__all__ = ["ZipImportError", "zipimporter"] +__all__ = ['ZipImportError', 'zipimporter'] def _unpack_uint32(data): """Convert 4 bytes in little-endian to an integer.""" assert len(data) == 4 - return int.from_bytes(data, "little") - + return int.from_bytes(data, 'little') def _unpack_uint16(data): """Convert 2 bytes in little-endian to an integer.""" assert len(data) == 2 - return int.from_bytes(data, "little") + return int.from_bytes(data, 'little') path_sep = _bootstrap_external.path_sep @@ -49,17 +47,15 @@ alt_path_sep = _bootstrap_external.path_separators[1:] class ZipImportError(ImportError): pass - # _read_directory() cache _zip_directory_cache = {} _module_type = type(sys) END_CENTRAL_DIR_SIZE = 22 -STRING_END_ARCHIVE = b"PK\x05\x06" +STRING_END_ARCHIVE = b'PK\x05\x06' MAX_COMMENT_LEN = (1 << 16) - 1 - class zipimporter: """zipimporter(archivepath) -> zipimporter object @@ -81,10 +77,9 @@ class zipimporter: def __init__(self, path): if not isinstance(path, str): import os - path = os.fsdecode(path) if not path: - raise ZipImportError("archive path is empty", path=path) + raise ZipImportError('archive path is empty', path=path) if alt_path_sep: path = path.replace(alt_path_sep, path_sep) @@ -97,14 +92,14 @@ class zipimporter: # Back up one path element. dirname, basename = _bootstrap_external._path_split(path) if dirname == path: - raise ZipImportError("not a Zip file", path=path) + raise ZipImportError('not a Zip file', path=path) path = dirname prefix.append(basename) else: # it exists if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG # it's a not file - raise ZipImportError("not a Zip file", path=path) + raise ZipImportError('not a Zip file', path=path) break try: @@ -159,10 +154,11 @@ class zipimporter: # This is possibly a portion of a namespace # package. Return the string representing its path, # without a trailing separator. - return None, [f"{self.archive}{path_sep}{modpath}"] + return None, [f'{self.archive}{path_sep}{modpath}'] return None, [] + # Check whether we can satisfy the import of the module named by # 'fullname'. Return self if we can, None if we can't. def find_module(self, fullname, path=None): @@ -176,6 +172,7 @@ class zipimporter: """ return self.find_loader(fullname, path)[0] + def get_code(self, fullname): """get_code(fullname) -> code object. @@ -185,6 +182,7 @@ class zipimporter: code, ispackage, modpath = _get_module_code(self, fullname) return code + def get_data(self, pathname): """get_data(pathname) -> string with file data. @@ -196,14 +194,15 @@ class zipimporter: key = pathname if pathname.startswith(self.archive + path_sep): - key = pathname[len(self.archive + path_sep) :] + key = pathname[len(self.archive + path_sep):] try: toc_entry = self._files[key] except KeyError: - raise OSError(0, "", key) + raise OSError(0, '', key) return _get_data(self.archive, toc_entry) + # Return a string matching __file__ for the named module def get_filename(self, fullname): """get_filename(fullname) -> filename string. @@ -215,6 +214,7 @@ class zipimporter: code, ispackage, modpath = _get_module_code(self, fullname) return modpath + def get_source(self, fullname): """get_source(fullname) -> source string. @@ -228,9 +228,9 @@ class zipimporter: path = _get_module_path(self, fullname) if mi: - fullpath = _bootstrap_external._path_join(path, "__init__.py") + fullpath = _bootstrap_external._path_join(path, '__init__.py') else: - fullpath = f"{path}.py" + fullpath = f'{path}.py' try: toc_entry = self._files[fullpath] @@ -239,6 +239,7 @@ class zipimporter: return None return _get_data(self.archive, toc_entry).decode() + # Return a bool signifying whether the module is a package or not. def is_package(self, fullname): """is_package(fullname) -> bool. @@ -251,6 +252,7 @@ class zipimporter: raise ZipImportError(f"can't find module {fullname!r}", name=fullname) return mi + # Load and return the module named by 'fullname'. def load_module(self, fullname): """load_module(fullname) -> module. @@ -274,7 +276,7 @@ class zipimporter: fullpath = _bootstrap_external._path_join(self.archive, path) mod.__path__ = [fullpath] - if not hasattr(mod, "__builtins__"): + if not hasattr(mod, '__builtins__'): mod.__builtins__ = __builtins__ _bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath) exec(code, mod.__dict__) @@ -285,10 +287,11 @@ class zipimporter: try: mod = sys.modules[fullname] except KeyError: - raise ImportError(f"Loaded module {fullname!r} not found in sys.modules") - _bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath) + raise ImportError(f'Loaded module {fullname!r} not found in sys.modules') + _bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath) return mod + def get_resource_reader(self, fullname): """Return the ResourceReader for a package in a zip file. @@ -302,11 +305,11 @@ class zipimporter: return None if not _ZipImportResourceReader._registered: from importlib.abc import ResourceReader - ResourceReader.register(_ZipImportResourceReader) _ZipImportResourceReader._registered = True return _ZipImportResourceReader(self, fullname) + def __repr__(self): return f'' @@ -317,18 +320,16 @@ class zipimporter: # are swapped by initzipimport() if we run in optimized mode. Also, # '/' is replaced by path_sep there. _zip_searchorder = ( - (path_sep + "__init__.pyc", True, True), - (path_sep + "__init__.py", False, True), - (".pyc", True, False), - (".py", False, False), + (path_sep + '__init__.pyc', True, True), + (path_sep + '__init__.py', False, True), + ('.pyc', True, False), + ('.py', False, False), ) - # Given a module name, return the potential file path in the # archive (without extension). def _get_module_path(self, fullname): - return self.prefix + fullname.rpartition(".")[2] - + return self.prefix + fullname.rpartition('.')[2] # Does this path represent a directory? def _is_dir(self, path): @@ -339,7 +340,6 @@ def _is_dir(self, path): # If dirpath is present in self._files, we have a directory. return dirpath in self._files - # Return some information about a module. def _get_module_info(self, fullname): path = _get_module_path(self, fullname) @@ -352,7 +352,6 @@ def _get_module_info(self, fullname): # implementation - # _read_directory(archive) -> files dict (new reference) # # Given a path to a Zip archive, build a dict, mapping file names @@ -375,7 +374,7 @@ def _get_module_info(self, fullname): # data_size and file_offset are 0. def _read_directory(archive): try: - fp = _io.open(archive, "rb") + fp = _io.open(archive, 'rb') except OSError: raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive) @@ -395,33 +394,36 @@ def _read_directory(archive): fp.seek(0, 2) file_size = fp.tell() except OSError: - raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) - max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0) + raise ZipImportError(f"can't read Zip file: {archive!r}", + path=archive) + max_comment_start = max(file_size - MAX_COMMENT_LEN - + END_CENTRAL_DIR_SIZE, 0) try: fp.seek(max_comment_start) data = fp.read() except OSError: - raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) + raise ZipImportError(f"can't read Zip file: {archive!r}", + path=archive) pos = data.rfind(STRING_END_ARCHIVE) if pos < 0: - raise ZipImportError(f"not a Zip file: {archive!r}", path=archive) - buffer = data[pos : pos + END_CENTRAL_DIR_SIZE] + raise ZipImportError(f'not a Zip file: {archive!r}', + path=archive) + buffer = data[pos:pos+END_CENTRAL_DIR_SIZE] if len(buffer) != END_CENTRAL_DIR_SIZE: - raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive) + raise ZipImportError(f"corrupt Zip file: {archive!r}", + path=archive) header_position = file_size - len(data) + pos header_size = _unpack_uint32(buffer[12:16]) header_offset = _unpack_uint32(buffer[16:20]) if header_position < header_size: - raise ZipImportError(f"bad central directory size: {archive!r}", path=archive) + raise ZipImportError(f'bad central directory size: {archive!r}', path=archive) if header_position < header_offset: - raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive) + raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive) header_position -= header_size arc_offset = header_position - header_offset if arc_offset < 0: - raise ZipImportError( - f"bad central directory size or offset: {archive!r}", path=archive - ) + raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive) files = {} # Start of Central Directory @@ -433,12 +435,12 @@ def _read_directory(archive): while True: buffer = fp.read(46) if len(buffer) < 4: - raise EOFError("EOF read where not expected") + raise EOFError('EOF read where not expected') # Start of file header - if buffer[:4] != b"PK\x01\x02": - break # Bad: Central Dir File Header + if buffer[:4] != b'PK\x01\x02': + break # Bad: Central Dir File Header if len(buffer) != 46: - raise EOFError("EOF read where not expected") + raise EOFError('EOF read where not expected') flags = _unpack_uint16(buffer[8:10]) compress = _unpack_uint16(buffer[10:12]) time = _unpack_uint16(buffer[12:14]) @@ -452,7 +454,7 @@ def _read_directory(archive): file_offset = _unpack_uint32(buffer[42:46]) header_size = name_size + extra_size + comment_size if file_offset > header_offset: - raise ZipImportError(f"bad local header offset: {archive!r}", path=archive) + raise ZipImportError(f'bad local header offset: {archive!r}', path=archive) file_offset += arc_offset try: @@ -476,19 +478,18 @@ def _read_directory(archive): else: # Historical ZIP filename encoding try: - name = name.decode("ascii") + name = name.decode('ascii') except UnicodeDecodeError: - name = name.decode("latin1").translate(cp437_table) + name = name.decode('latin1').translate(cp437_table) - name = name.replace("/", path_sep) + name = name.replace('/', path_sep) path = _bootstrap_external._path_join(archive, name) t = (path, compress, data_size, file_size, file_offset, time, date, crc) files[name] = t count += 1 - _bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive) + _bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive) return files - # During bootstrap, we may need to load the encodings # package from a ZIP file. But the cp437 encoding is implemented # in Python in the encodings package. @@ -497,36 +498,35 @@ def _read_directory(archive): # the cp437 encoding. cp437_table = ( # ASCII part, 8 rows x 16 chars - "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f" - "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f" - " !\"#$%&'()*+,-./" - "0123456789:;<=>?" - "@ABCDEFGHIJKLMNO" - "PQRSTUVWXYZ[\\]^_" - "`abcdefghijklmno" - "pqrstuvwxyz{|}~\x7f" + '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f' + '\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f' + ' !"#$%&\'()*+,-./' + '0123456789:;<=>?' + '@ABCDEFGHIJKLMNO' + 'PQRSTUVWXYZ[\\]^_' + '`abcdefghijklmno' + 'pqrstuvwxyz{|}~\x7f' # non-ASCII part, 16 rows x 8 chars - "\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7" - "\xea\xeb\xe8\xef\xee\xec\xc4\xc5" - "\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9" - "\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192" - "\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba" - "\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb" - "\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556" - "\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510" - "\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f" - "\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567" - "\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b" - "\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580" - "\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4" - "\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229" - "\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248" - "\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0" + '\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7' + '\xea\xeb\xe8\xef\xee\xec\xc4\xc5' + '\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9' + '\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192' + '\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba' + '\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb' + '\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556' + '\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510' + '\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f' + '\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567' + '\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b' + '\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580' + '\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4' + '\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229' + '\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248' + '\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0' ) _importing_zlib = False - # Return the zlib.decompress function object, or NULL if zlib couldn't # be imported. The function is cached when found, so subsequent calls # don't import zlib again. @@ -535,29 +535,28 @@ def _get_decompress_func(): if _importing_zlib: # Someone has a zlib.py[co] in their Zip file # let's avoid a stack overflow. - _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") + _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') raise ZipImportError("can't decompress data; zlib not available") _importing_zlib = True try: from zlib import decompress except Exception: - _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") + _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') raise ZipImportError("can't decompress data; zlib not available") finally: _importing_zlib = False - _bootstrap._verbose_message("zipimport: zlib available") + _bootstrap._verbose_message('zipimport: zlib available') return decompress - # Given a path to a Zip file and a toc_entry, return the (uncompressed) data. def _get_data(archive, toc_entry): datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry if data_size < 0: - raise ZipImportError("negative data size") + raise ZipImportError('negative data size') - with _io.open(archive, "rb") as fp: + with _io.open(archive, 'rb') as fp: # Check to make sure the local file header is correct try: fp.seek(file_offset) @@ -565,11 +564,11 @@ def _get_data(archive, toc_entry): raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) buffer = fp.read(30) if len(buffer) != 30: - raise EOFError("EOF read where not expected") + raise EOFError('EOF read where not expected') - if buffer[:4] != b"PK\x03\x04": + if buffer[:4] != b'PK\x03\x04': # Bad: Local File Header - raise ZipImportError(f"bad local file header: {archive!r}", path=archive) + raise ZipImportError(f'bad local file header: {archive!r}', path=archive) name_size = _unpack_uint16(buffer[26:28]) extra_size = _unpack_uint16(buffer[28:30]) @@ -602,17 +601,16 @@ def _eq_mtime(t1, t2): # dostime only stores even seconds, so be lenient return abs(t1 - t2) <= 1 - # Given the contents of a .py[co] file, unmarshal the data # and return the code object. Return None if it the magic word doesn't # match (we do this instead of raising an exception as we fall back # to .py if available and we don't want to mask other errors). def _unmarshal_code(pathname, data, mtime): if len(data) < 16: - raise ZipImportError("bad pyc data") + raise ZipImportError('bad pyc data') if data[:4] != _bootstrap_external.MAGIC_NUMBER: - _bootstrap._verbose_message("{!r} has bad magic", pathname) + _bootstrap._verbose_message('{!r} has bad magic', pathname) return None # signal caller to try alternative flags = _unpack_uint32(data[4:8]) @@ -621,57 +619,47 @@ def _unmarshal_code(pathname, data, mtime): # pycs. We could validate hash-based pycs against the source, but it # seems likely that most people putting hash-based pycs in a zipfile # will use unchecked ones. - if _imp.check_hash_based_pycs != "never" and ( - flags != 0x1 or _imp.check_hash_based_pycs == "always" - ): + if (_imp.check_hash_based_pycs != 'never' and + (flags != 0x1 or _imp.check_hash_based_pycs == 'always')): return None elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime): - _bootstrap._verbose_message("{!r} has bad mtime", pathname) + _bootstrap._verbose_message('{!r} has bad mtime', pathname) return None # signal caller to try alternative # XXX the pyc's size field is ignored; timestamp collisions are probably # unimportant with zip files. code = marshal.loads(data[16:]) if not isinstance(code, _code_type): - raise TypeError(f"compiled module {pathname!r} is not a code object") + raise TypeError(f'compiled module {pathname!r} is not a code object') return code - _code_type = type(_unmarshal_code.__code__) # Replace any occurrences of '\r\n?' in the input string with '\n'. # This converts DOS and Mac line endings to Unix line endings. def _normalize_line_endings(source): - source = source.replace(b"\r\n", b"\n") - source = source.replace(b"\r", b"\n") + source = source.replace(b'\r\n', b'\n') + source = source.replace(b'\r', b'\n') return source - # Given a string buffer containing Python source code, compile it # and return a code object. def _compile_source(pathname, source): source = _normalize_line_endings(source) - return compile(source, pathname, "exec", dont_inherit=True) - + return compile(source, pathname, 'exec', dont_inherit=True) # Convert the date/time values found in the Zip archive to a value # that's compatible with the time stamp stored in .pyc files. def _parse_dostime(d, t): - return time.mktime( - ( - (d >> 9) + 1980, # bits 9..15: year - (d >> 5) & 0xF, # bits 5..8: month - d & 0x1F, # bits 0..4: day - t >> 11, # bits 11..15: hours - (t >> 5) & 0x3F, # bits 8..10: minutes - (t & 0x1F) * 2, # bits 0..7: seconds / 2 - -1, - -1, - -1, - ) - ) - + return time.mktime(( + (d >> 9) + 1980, # bits 9..15: year + (d >> 5) & 0xF, # bits 5..8: month + d & 0x1F, # bits 0..4: day + t >> 11, # bits 11..15: hours + (t >> 5) & 0x3F, # bits 8..10: minutes + (t & 0x1F) * 2, # bits 0..7: seconds / 2 + -1, -1, -1)) # Given a path to a .pyc file in the archive, return the # modification time of the matching .py file, or 0 if no source @@ -679,7 +667,7 @@ def _parse_dostime(d, t): def _get_mtime_of_source(self, path): try: # strip 'c' or 'o' from *.py[co] - assert path[-1:] in ("c", "o") + assert path[-1:] in ('c', 'o') path = path[:-1] toc_entry = self._files[path] # fetch the time stamp of the .py file for comparison @@ -690,14 +678,13 @@ def _get_mtime_of_source(self, path): except (KeyError, IndexError, TypeError): return 0 - # Get the code object associated with the module specified by # 'fullname'. def _get_module_code(self, fullname): path = _get_module_path(self, fullname) for suffix, isbytecode, ispackage in _zip_searchorder: fullpath = path + suffix - _bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2) + _bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2) try: toc_entry = self._files[fullpath] except KeyError: @@ -726,7 +713,6 @@ class _ZipImportResourceReader: This class is allowed to reference all the innards and private parts of the zipimporter. """ - _registered = False def __init__(self, zipimporter, fullname): @@ -734,10 +720,9 @@ class _ZipImportResourceReader: self.fullname = fullname def open_resource(self, resource): - fullname_as_path = self.fullname.replace(".", "/") - path = f"{fullname_as_path}/{resource}" + fullname_as_path = self.fullname.replace('.', '/') + path = f'{fullname_as_path}/{resource}' from io import BytesIO - try: return BytesIO(self.zipimporter.get_data(path)) except OSError: @@ -752,8 +737,8 @@ class _ZipImportResourceReader: def is_resource(self, name): # Maybe we could do better, but if we can get the data, it's a # resource. Otherwise it isn't. - fullname_as_path = self.fullname.replace(".", "/") - path = f"{fullname_as_path}/{name}" + fullname_as_path = self.fullname.replace('.', '/') + path = f'{fullname_as_path}/{name}' try: self.zipimporter.get_data(path) except OSError: @@ -769,12 +754,11 @@ class _ZipImportResourceReader: # top of the archive, and then we iterate through _files looking for # names inside that "directory". from pathlib import Path - fullname_path = Path(self.zipimporter.get_filename(self.fullname)) relative_path = fullname_path.relative_to(self.zipimporter.archive) # Don't forget that fullname names a package, so its path will include # __init__.py, which we want to ignore. - assert relative_path.name == "__init__.py" + assert relative_path.name == '__init__.py' package_path = relative_path.parent subdirs_seen = set() for filename in self.zipimporter._files: diff --git a/maubot/loader/__init__.py b/maubot/loader/__init__.py index d61be5c..b783152 100644 --- a/maubot/loader/__init__.py +++ b/maubot/loader/__init__.py @@ -1,3 +1,2 @@ -from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader -from .meta import DatabaseType, PluginMeta -from .zip import MaubotZipImportError, ZippedPluginLoader +from .abc import BasePluginLoader, PluginLoader, PluginClass, IDConflictError, PluginMeta +from .zip import ZippedPluginLoader, MaubotZipImportError diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index c2c71b2..f148451 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,14 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING, TypeVar +from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING from abc import ABC, abstractmethod import asyncio +from attr import dataclass +from packaging.version import Version, InvalidVersion + +from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer + +from ..__meta__ import __version__ from ..plugin_base import Plugin -from .meta import PluginMeta if TYPE_CHECKING: from ..instance import PluginInstance @@ -32,6 +35,36 @@ class IDConflictError(Exception): pass +@serializer(Version) +def serialize_version(version: Version) -> str: + return str(version) + + +@deserializer(Version) +def deserialize_version(version: str) -> Version: + try: + return Version(version) + except InvalidVersion as e: + raise SerializerError("Invalid version") from e + + +@dataclass +class PluginMeta(SerializableAttrs['PluginMeta']): + id: str + version: Version + modules: List[str] + main_class: str + + maubot: Version = Version(__version__) + database: bool = False + config: bool = False + webapp: bool = False + license: str = "" + extra_files: List[str] = [] + dependencies: List[str] = [] + soft_dependencies: List[str] = [] + + class BasePluginLoader(ABC): meta: PluginMeta @@ -47,25 +80,25 @@ class BasePluginLoader(ABC): async def read_file(self, path: str) -> bytes: pass - def sync_list_files(self, directory: str) -> list[str]: + def sync_list_files(self, directory: str) -> List[str]: raise NotImplementedError("This loader doesn't support synchronous operations") @abstractmethod - async def list_files(self, directory: str) -> list[str]: + async def list_files(self, directory: str) -> List[str]: pass class PluginLoader(BasePluginLoader, ABC): - id_cache: dict[str, PluginLoader] = {} + id_cache: Dict[str, 'PluginLoader'] = {} meta: PluginMeta - references: set[PluginInstance] + references: Set['PluginInstance'] def __init__(self): self.references = set() @classmethod - def find(cls, plugin_id: str) -> PluginLoader: + def find(cls, plugin_id: str) -> 'PluginLoader': return cls.id_cache[plugin_id] def to_dict(self) -> dict: @@ -76,21 +109,23 @@ class PluginLoader(BasePluginLoader, ABC): } async def stop_instances(self) -> None: - await asyncio.gather( - *[instance.stop() for instance in self.references if instance.started] - ) + await asyncio.gather(*[instance.stop() for instance + in self.references if instance.started]) async def start_instances(self) -> None: - await asyncio.gather( - *[instance.start() for instance in self.references if instance.enabled] - ) + await asyncio.gather(*[instance.start() for instance + in self.references if instance.enabled]) @abstractmethod - async def load(self) -> type[PluginClass]: + async def load(self) -> Type[PluginClass]: pass @abstractmethod - async def reload(self) -> type[PluginClass]: + async def reload(self) -> Type[PluginClass]: + pass + + @abstractmethod + async def unload(self) -> None: pass @abstractmethod diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py deleted file mode 100644 index d368e24..0000000 --- a/maubot/loader/meta.py +++ /dev/null @@ -1,69 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import List, Optional - -from attr import dataclass -from packaging.version import InvalidVersion, Version - -from mautrix.types import ( - ExtensibleEnum, - SerializableAttrs, - SerializerError, - deserializer, - serializer, -) - -from ..__meta__ import __version__ - - -@serializer(Version) -def serialize_version(version: Version) -> str: - return str(version) - - -@deserializer(Version) -def deserialize_version(version: str) -> Version: - try: - return Version(version) - except InvalidVersion as e: - raise SerializerError("Invalid version") from e - - -class DatabaseType(ExtensibleEnum): - SQLALCHEMY = "sqlalchemy" - ASYNCPG = "asyncpg" - - -@dataclass -class PluginMeta(SerializableAttrs): - id: str - version: Version - modules: List[str] - main_class: str - - maubot: Version = Version(__version__) - database: bool = False - database_type: DatabaseType = DatabaseType.SQLALCHEMY - config: bool = False - webapp: bool = False - license: str = "" - extra_files: List[str] = [] - dependencies: List[str] = [] - soft_dependencies: List[str] = [] - - @property - def database_type_str(self) -> Optional[str]: - return self.database_type.value if self.database else None diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 8642183..6d8a8ce 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,27 +13,23 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - +from typing import Dict, List, Type, Tuple, Optional +from zipfile import ZipFile, BadZipFile from time import time -from zipfile import BadZipFile, ZipFile import logging -import os import sys +import os -from packaging.version import Version from ruamel.yaml import YAML, YAMLError +from packaging.version import Version from mautrix.types import SerializerError -from ..__meta__ import __version__ -from ..config import Config -from ..lib.zipimport import ZipImportError, zipimporter +from ..lib.zipimport import zipimporter, ZipImportError from ..plugin_base import Plugin -from .abc import IDConflictError, PluginClass, PluginLoader -from .meta import DatabaseType, PluginMeta +from ..config import Config +from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError -current_version = Version(__version__) yaml = YAML() @@ -54,25 +50,23 @@ class MaubotZipLoadError(MaubotZipImportError): class ZippedPluginLoader(PluginLoader): - path_cache: dict[str, ZippedPluginLoader] = {} + path_cache: Dict[str, 'ZippedPluginLoader'] = {} log: logging.Logger = logging.getLogger("maubot.loader.zip") trash_path: str = "delete" - directories: list[str] = [] + directories: List[str] = [] - path: str | None - meta: PluginMeta | None - main_class: str | None - main_module: str | None - _loaded: type[PluginClass] | None - _importer: zipimporter | None - _file: ZipFile | None + path: str + meta: PluginMeta + main_class: str + main_module: str + _loaded: Type[PluginClass] + _importer: zipimporter + _file: ZipFile def __init__(self, path: str) -> None: super().__init__() self.path = path self.meta = None - self.main_class = None - self.main_module = None self._loaded = None self._importer = None self._file = None @@ -81,8 +75,7 @@ class ZippedPluginLoader(PluginLoader): try: existing = self.id_cache[self.meta.id] raise IDConflictError( - f"Plugin with id {self.meta.id} already loaded from {existing.source}" - ) + f"Plugin with id {self.meta.id} already loaded from {existing.source}") except KeyError: pass self.path_cache[self.path] = self @@ -90,10 +83,13 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}") def to_dict(self) -> dict: - return {**super().to_dict(), "path": self.path} + return { + **super().to_dict(), + "path": self.path + } @classmethod - def get(cls, path: str) -> ZippedPluginLoader: + def get(cls, path: str) -> 'ZippedPluginLoader': path = os.path.abspath(path) try: return cls.path_cache[path] @@ -105,12 +101,10 @@ class ZippedPluginLoader(PluginLoader): return self.path def __repr__(self) -> str: - return ( - "" - ) + return ("") def sync_read_file(self, path: str) -> bytes: return self._file.read(path) @@ -118,19 +112,16 @@ class ZippedPluginLoader(PluginLoader): async def read_file(self, path: str) -> bytes: return self.sync_read_file(path) - def sync_list_files(self, directory: str) -> list[str]: + def sync_list_files(self, directory: str) -> List[str]: directory = directory.rstrip("/") - return [ - file.filename - for file in self._file.filelist - if os.path.dirname(file.filename) == directory - ] + return [file.filename for file in self._file.filelist + if os.path.dirname(file.filename) == directory] - async def list_files(self, directory: str) -> list[str]: + async def list_files(self, directory: str) -> List[str]: return self.sync_list_files(directory) @staticmethod - def _read_meta(source) -> tuple[ZipFile, PluginMeta]: + def _read_meta(source) -> Tuple[ZipFile, PluginMeta]: try: file = ZipFile(source) data = file.read("maubot.yaml") @@ -148,16 +139,12 @@ class ZippedPluginLoader(PluginLoader): meta = PluginMeta.deserialize(meta_dict) except SerializerError as e: raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e - if meta.maubot > current_version: - raise MaubotZipMetaError( - f"Plugin requires maubot {meta.maubot}, but this instance is {current_version}" - ) return file, meta @classmethod - def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]: + def verify_meta(cls, source) -> Tuple[str, Version]: _, meta = cls._read_meta(source) - return meta.id, meta.version, meta.database_type if meta.database else None + return meta.id, meta.version def _load_meta(self) -> None: file, meta = self._read_meta(self.path) @@ -167,7 +154,7 @@ class ZippedPluginLoader(PluginLoader): if "/" in meta.main_class: self.main_module, self.main_class = meta.main_class.split("/")[:2] else: - self.main_module = meta.modules[-1] + self.main_module = meta.modules[0] self.main_class = meta.main_class self._file = file @@ -186,24 +173,24 @@ class ZippedPluginLoader(PluginLoader): code = importer.get_code(self.main_module.replace(".", "/")) if self.main_class not in code.co_names: raise MaubotZipPreLoadError( - f"Main class {self.main_class} not in {self.main_module}" - ) + f"Main class {self.main_class} not in {self.main_module}") except ZipImportError as e: - raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e + raise MaubotZipPreLoadError( + f"Main module {self.main_module} not found in file") from e for module in self.meta.modules: try: importer.find_module(module) except ZipImportError as e: raise MaubotZipPreLoadError(f"Module {module} not found in file") from e - async def load(self, reset_cache: bool = False) -> type[PluginClass]: + async def load(self, reset_cache: bool = False) -> Type[PluginClass]: try: return self._load(reset_cache) except MaubotZipImportError: self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}") raise - def _load(self, reset_cache: bool = False) -> type[PluginClass]: + def _load(self, reset_cache: bool = False) -> Type[PluginClass]: if self._loaded is not None and not reset_cache: return self._loaded self._load_meta() @@ -232,18 +219,13 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}") return plugin - async def reload(self, new_path: str | None = None) -> type[PluginClass]: - self._unload() - if new_path is not None and new_path != self.path: - try: - del self.path_cache[self.path] - except KeyError: - pass + async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]: + await self.unload() + if new_path is not None: self.path = new_path - self.path_cache[self.path] = self return await self.load(reset_cache=True) - def _unload(self) -> None: + async def unload(self) -> None: for name, mod in list(sys.modules.items()): if (getattr(mod, "__file__", "") or "").startswith(self.path): del sys.modules[name] @@ -251,7 +233,7 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}") async def delete(self) -> None: - self._unload() + await self.unload() try: del self.path_cache[self.path] except KeyError: @@ -269,22 +251,12 @@ class ZippedPluginLoader(PluginLoader): self.path = None @classmethod - def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None: + def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None: if cls.trash_path == "delete": - try: - os.remove(file_path) - except FileNotFoundError: - pass + os.remove(file_path) else: new_name = new_name or f"{int(time())}-{reason}-{os.path.basename(file_path)}" - try: - os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name))) - except OSError as e: - cls.log.warning(f"Failed to rename {file_path}: {e} - trying to delete") - try: - os.remove(file_path) - except FileNotFoundError: - pass + os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name))) @classmethod def load_all(cls): diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py index c2e5f24..5326039 100644 --- a/maubot/management/api/__init__.py +++ b/maubot/management/api/__init__.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,14 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from aiohttp import web from asyncio import AbstractEventLoop import importlib -from aiohttp import web - from ...config import Config +from .base import routes, get_config, set_config, set_loop from .auth import check_token -from .base import get_config, routes, set_config from .middleware import auth, error @@ -31,15 +30,14 @@ def features(request: web.Request) -> web.Response: if err is None: return web.json_response(data) else: - return web.json_response( - { - "login": data["login"], - } - ) + return web.json_response({ + "login": data["login"], + }) def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: set_config(cfg) + set_loop(loop) for pkg, enabled in cfg["api_features"].items(): if enabled: importlib.import_module(f"maubot.management.api.{pkg}") diff --git a/maubot/management/api/auth.py b/maubot/management/api/auth.py index 0abc3ad..4675301 100644 --- a/maubot/management/api/auth.py +++ b/maubot/management/api/auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,8 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - +from typing import Optional from time import time from aiohttp import web @@ -22,7 +21,7 @@ from aiohttp import web from mautrix.types import UserID from mautrix.util.signed_token import sign_token, verify_token -from .base import get_config, routes +from .base import routes, get_config from .responses import resp @@ -34,25 +33,22 @@ def is_valid_token(token: str) -> bool: def create_token(user: UserID) -> str: - return sign_token( - get_config()["server.unshared_secret"], - { - "user_id": user, - "created_at": int(time()), - }, - ) + return sign_token(get_config()["server.unshared_secret"], { + "user_id": user, + "created_at": int(time()), + }) def get_token(request: web.Request) -> str: token = request.headers.get("Authorization", "") if not token or not token.startswith("Bearer "): - token = request.query.get("access_token", "") + token = request.query.get("access_token", None) else: - token = token[len("Bearer ") :] + token = token[len("Bearer "):] return token -def check_token(request: web.Request) -> web.Response | None: +def check_token(request: web.Request) -> Optional[web.Response]: token = get_token(request) if not token: return resp.no_token diff --git a/maubot/management/api/base.py b/maubot/management/api/base.py index 3d7693a..b6a5dea 100644 --- a/maubot/management/api/base.py +++ b/maubot/management/api/base.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,17 +13,15 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -import asyncio - from aiohttp import web +import asyncio from ...__meta__ import __version__ from ...config import Config routes: web.RouteTableDef = web.RouteTableDef() -_config: Config | None = None +_config: Config = None +_loop: asyncio.AbstractEventLoop = None def set_config(config: Config) -> None: @@ -35,6 +33,17 @@ def get_config() -> Config: return _config +def set_loop(loop: asyncio.AbstractEventLoop) -> None: + global _loop + _loop = loop + + +def get_loop() -> asyncio.AbstractEventLoop: + return _loop + + @routes.get("/version") async def version(_: web.Request) -> web.Response: - return web.json_response({"version": __version__}) + return web.json_response({ + "version": __version__ + }) diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index d2ad35d..0585d63 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,23 +13,20 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - +from typing import Optional from json import JSONDecodeError -import logging from aiohttp import web +from mautrix.types import UserID, SyncToken, FilterID +from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken from mautrix.client import Client as MatrixClient -from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError -from mautrix.types import FilterID, SyncToken, UserID +from ...db import DBClient from ...client import Client from .base import routes from .responses import resp -log = logging.getLogger("maubot.server.client") - @routes.get("/clients") async def get_clients(_: web.Request) -> web.Response: @@ -39,94 +36,64 @@ async def get_clients(_: web.Request) -> web.Response: @routes.get("/client/{id}") async def get_client(request: web.Request) -> web.Response: user_id = request.match_info.get("id", None) - client = await Client.get(user_id) + client = Client.get(user_id, None) if not client: return resp.client_not_found return resp.found(client.to_dict()) -async def _create_client(user_id: UserID | None, data: dict) -> web.Response: +async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: homeserver = data.get("homeserver", None) access_token = data.get("access_token", None) - device_id = data.get("device_id", None) - new_client = MatrixClient( - mxid="@not:a.mxid", - base_url=homeserver, - token=access_token, - client_session=Client.http_client, - ) + new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token, + loop=Client.loop, client_session=Client.http_client) try: - whoami = await new_client.whoami() - except MatrixInvalidToken as e: + mxid = await new_client.whoami() + except MatrixInvalidToken: return resp.bad_client_access_token except MatrixRequestError: - log.warning(f"Failed to get whoami from {homeserver} for new client", exc_info=True) return resp.bad_client_access_details except MatrixConnectionError: - log.warning(f"Failed to connect to {homeserver} for new client", exc_info=True) return resp.bad_client_connection_details if user_id is None: - existing_client = await Client.get(whoami.user_id) + existing_client = Client.get(mxid, None) if existing_client is not None: return resp.user_exists - elif whoami.user_id != user_id: - return resp.mxid_mismatch(whoami.user_id) - elif whoami.device_id and device_id and whoami.device_id != device_id: - return resp.device_id_mismatch(whoami.device_id) - client = await Client.get( - whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id - ) - client.enabled = data.get("enabled", True) - client.sync = data.get("sync", True) - await client.update_autojoin(data.get("autojoin", True), save=False) - await client.update_online(data.get("online", True), save=False) - client.displayname = data.get("displayname", "disable") - client.avatar_url = data.get("avatar_url", "disable") - await client.update() + elif mxid != user_id: + return resp.mxid_mismatch(mxid) + db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token, + enabled=data.get("enabled", True), next_batch=SyncToken(""), + filter_id=FilterID(""), sync=data.get("sync", True), + autojoin=data.get("autojoin", True), online=data.get("online", True), + displayname=data.get("displayname", ""), + avatar_url=data.get("avatar_url", "")) + client = Client(db_instance) + client.db_instance.insert() await client.start() return resp.created(client.to_dict()) -async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: +async def _update_client(client: Client, data: dict) -> web.Response: try: - await client.update_access_details( - data.get("access_token"), data.get("homeserver"), data.get("device_id") - ) + await client.update_access_details(data.get("access_token", None), + data.get("homeserver", None)) except MatrixInvalidToken: return resp.bad_client_access_token except MatrixRequestError: - log.warning( - f"Failed to get whoami from homeserver to update client details", exc_info=True - ) return resp.bad_client_access_details except MatrixConnectionError: - log.warning(f"Failed to connect to homeserver to update client details", exc_info=True) return resp.bad_client_connection_details except ValueError as e: - str_err = str(e) - if str_err.startswith("MXID mismatch"): - return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :]) - elif str_err.startswith("Device ID mismatch"): - return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :]) - await client.update_avatar_url(data.get("avatar_url"), save=False) - await client.update_displayname(data.get("displayname"), save=False) - await client.update_started(data.get("started")) - await client.update_enabled(data.get("enabled"), save=False) - await client.update_autojoin(data.get("autojoin"), save=False) - await client.update_online(data.get("online"), save=False) - await client.update_sync(data.get("sync"), save=False) - await client.update() - return resp.updated(client.to_dict(), is_login=is_login) - - -async def _create_or_update_client( - user_id: UserID, data: dict, is_login: bool = False -) -> web.Response: - client = await Client.get(user_id) - if not client: - return await _create_client(user_id, data) - else: - return await _update_client(client, data, is_login=is_login) + return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):]) + with client.db_instance.edit_mode(): + await client.update_avatar_url(data.get("avatar_url", None)) + await client.update_displayname(data.get("displayname", None)) + await client.update_started(data.get("started", None)) + client.enabled = data.get("enabled", client.enabled) + client.autojoin = data.get("autojoin", client.autojoin) + client.online = data.get("online", client.online) + client.sync = data.get("sync", client.sync) + return resp.updated(client.to_dict()) @routes.post("/client/new") @@ -140,33 +107,37 @@ async def create_client(request: web.Request) -> web.Response: @routes.put("/client/{id}") async def update_client(request: web.Request) -> web.Response: - user_id = request.match_info["id"] + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) try: data = await request.json() except JSONDecodeError: return resp.body_not_json - return await _create_or_update_client(user_id, data) + if not client: + return await _create_client(user_id, data) + else: + return await _update_client(client, data) @routes.delete("/client/{id}") async def delete_client(request: web.Request) -> web.Response: - user_id = request.match_info["id"] - client = await Client.get(user_id) + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) if not client: return resp.client_not_found if len(client.references) > 0: return resp.client_in_use if client.started: await client.stop() - await client.delete() + client.delete() return resp.deleted @routes.post("/client/{id}/clearcache") async def clear_client_cache(request: web.Request) -> web.Response: - user_id = request.match_info["id"] - client = await Client.get(user_id) + user_id = request.match_info.get("id", None) + client = Client.get(user_id, None) if not client: return resp.client_not_found - await client.clear_cache() + client.clear_cache() return resp.ok diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py index 4e5e201..36c4d9a 100644 --- a/maubot/management/api/client_auth.py +++ b/maubot/management/api/client_auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,105 +13,66 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import NamedTuple -from http import HTTPStatus +from typing import Dict, Tuple, NamedTuple, Optional from json import JSONDecodeError -import asyncio +from http import HTTPStatus import hashlib -import hmac import random import string +import hmac from aiohttp import web -from yarl import URL - -from mautrix.api import Method, Path, SynapseAdminPath -from mautrix.client import ClientAPI +from mautrix.api import HTTPAPI, Path, SynapseAdminPath, Method from mautrix.errors import MatrixRequestError -from mautrix.types import LoginResponse, LoginType -from .base import get_config, routes -from .client import _create_client, _create_or_update_client +from .base import routes, get_config, get_loop from .responses import resp -def known_homeservers() -> dict[str, dict[str, str]]: - return get_config()["homeservers"] +def registration_secrets() -> Dict[str, Dict[str, str]]: + return get_config()["registration_secrets"] @routes.get("/client/auth/servers") -async def get_known_servers(_: web.Request) -> web.Response: - return web.json_response({key: value["url"] for key, value in known_homeservers().items()}) +async def get_registerable_servers(_: web.Request) -> web.Response: + return web.json_response({key: value["url"] for key, value in registration_secrets().items()}) -class AuthRequestInfo(NamedTuple): - server_name: str - client: ClientAPI - secret: str - username: str - password: str - user_type: str - device_name: str - update_client: bool - sso: bool +AuthRequestInfo = NamedTuple("AuthRequestInfo", api=HTTPAPI, secret=str, username=str, + password=str, user_type=str) -truthy_strings = ("1", "true", "yes") - - -async def read_client_auth_request( - request: web.Request, -) -> tuple[AuthRequestInfo | None, web.Response | None]: +async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], + Optional[web.Response]]: server_name = request.match_info.get("server", None) - server = known_homeservers().get(server_name, None) + server = registration_secrets().get(server_name, None) if not server: return None, resp.server_not_found try: body = await request.json() except JSONDecodeError: return None, resp.body_not_json - sso = request.query.get("sso", "").lower() in truthy_strings try: username = body["username"] password = body["password"] except KeyError: - if not sso: - return None, resp.username_or_password_missing - username = password = None + return None, resp.username_or_password_missing try: base_url = server["url"] + secret = server["secret"] except KeyError: return None, resp.invalid_server - return ( - AuthRequestInfo( - server_name=server_name, - client=ClientAPI(base_url=base_url), - secret=server.get("secret"), - username=username, - password=password, - user_type=body.get("user_type", "bot"), - device_name=body.get("device_name", "Maubot"), - update_client=request.query.get("update_client", "").lower() in truthy_strings, - sso=sso, - ), - None, - ) + api = HTTPAPI(base_url, "", loop=get_loop()) + user_type = body.get("user_type", "bot") + return AuthRequestInfo(api, secret, username, password, user_type), None -def generate_mac( - secret: str, - nonce: str, - username: str, - password: str, - admin: bool = False, - user_type: str = None, -) -> str: +def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False, + user_type: str = None) -> str: mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1) mac.update(nonce.encode("utf-8")) mac.update(b"\x00") - mac.update(username.encode("utf-8")) + mac.update(user.encode("utf-8")) mac.update(b"\x00") mac.update(password.encode("utf-8")) mac.update(b"\x00") @@ -124,150 +85,49 @@ def generate_mac( @routes.post("/client/auth/{server}/register") async def register(request: web.Request) -> web.Response: - req, err = await read_client_auth_request(request) + info, err = await read_client_auth_request(request) if err is not None: return err - if req.sso: - return resp.registration_no_sso - elif not req.secret: - return resp.registration_secret_not_found + api, secret, username, password, user_type = info path = SynapseAdminPath.v1.register - res = await req.client.api.request(Method.GET, path) + res = await api.request(Method.GET, path) content = { "nonce": res["nonce"], - "username": req.username, - "password": req.password, + "username": username, + "password": password, "admin": False, - "user_type": req.user_type, + "mac": generate_mac(secret, res["nonce"], username, password, user_type=user_type), + "user_type": user_type, } - content["mac"] = generate_mac(**content, secret=req.secret) try: - raw_res = await req.client.api.request(Method.POST, path, content=content) + return web.json_response(await api.request(Method.POST, path, content=content)) except MatrixRequestError as e: - return web.json_response( - { - "errcode": e.errcode, - "error": e.message, - "http_status": e.http_status, - }, - status=HTTPStatus.INTERNAL_SERVER_ERROR, - ) - login_res = LoginResponse.deserialize(raw_res) - if req.update_client: - return await _create_client( - login_res.user_id, - { - "homeserver": str(req.client.api.base_url), - "access_token": login_res.access_token, - "device_id": login_res.device_id, - }, - ) - return web.json_response(login_res.serialize()) + return web.json_response({ + "errcode": e.errcode, + "error": e.message, + "http_status": e.http_status, + }, status=HTTPStatus.INTERNAL_SERVER_ERROR) @routes.post("/client/auth/{server}/login") async def login(request: web.Request) -> web.Response: - req, err = await read_client_auth_request(request) + info, err = await read_client_auth_request(request) if err is not None: return err - if req.sso: - return await _do_sso(req) - else: - return await _do_login(req) - - -async def _do_sso(req: AuthRequestInfo) -> web.Response: - flows = await req.client.get_login_flows() - if not flows.supports_type(LoginType.SSO): - return resp.sso_not_supported - waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16)) - cfg = get_config() - public_url = ( - URL(cfg["server.public_url"]) - / "_matrix/maubot/v1/client/auth_external_sso/complete" - / waiter_id - ) - sso_url = req.client.api.base_url.with_path(str(Path.v3.login.sso.redirect)).with_query( - {"redirectUrl": str(public_url)} - ) - sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future() - return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) - - -async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response: - device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) - device_id = f"maubot_{device_id}" + api, _, username, password, _ = info + device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) try: - if req.sso: - res = await req.client.login( - token=login_token, - login_type=LoginType.TOKEN, - device_id=device_id, - store_access_token=False, - initial_device_display_name=req.device_name, - ) - else: - res = await req.client.login( - identifier=req.username, - login_type=LoginType.PASSWORD, - password=req.password, - device_id=device_id, - initial_device_display_name=req.device_name, - store_access_token=False, - ) + return web.json_response(await api.request(Method.POST, Path.login, content={ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": username, + }, + "password": password, + "device_id": f"maubot_{device_id}", + })) except MatrixRequestError as e: - return web.json_response( - { - "errcode": e.errcode, - "error": e.message, - }, - status=e.http_status, - ) - if req.update_client: - return await _create_or_update_client( - res.user_id, - { - "homeserver": str(req.client.api.base_url), - "access_token": res.access_token, - "device_id": res.device_id, - }, - is_login=True, - ) - return web.json_response(res.serialize()) - - -sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {} - - -@routes.post("/client/auth/{server}/sso/{id}/wait") -async def wait_sso(request: web.Request) -> web.Response: - waiter_id = request.match_info["id"] - req, fut = sso_waiters[waiter_id] - try: - login_token = await fut - finally: - sso_waiters.pop(waiter_id, None) - return await _do_login(req, login_token) - - -@routes.get("/client/auth_external_sso/complete/{id}") -async def complete_sso(request: web.Request) -> web.Response: - try: - _, fut = sso_waiters[request.match_info["id"]] - except KeyError: - return web.Response(status=404, text="Invalid session ID\n") - if fut.cancelled(): - return web.Response(status=200, text="The login was cancelled from the Maubot client\n") - elif fut.done(): - return web.Response(status=200, text="The login token was already received\n") - try: - fut.set_result(request.query["loginToken"]) - except KeyError: - return web.Response(status=400, text="Missing loginToken query parameter\n") - except asyncio.InvalidStateError: - return web.Response(status=500, text="Invalid state\n") - return web.Response( - status=200, - text="Login token received, please return to your Maubot client. " - "This tab can be closed.\n", - ) + return web.json_response({ + "errcode": e.errcode, + "error": e.message, + }, status=e.http_status) diff --git a/maubot/management/api/client_proxy.py b/maubot/management/api/client_proxy.py index 3fa682b..8c293cd 100644 --- a/maubot/management/api/client_proxy.py +++ b/maubot/management/api/client_proxy.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import client as http, web +from aiohttp import web, client as http from ...client import Client from .base import routes @@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024 @routes.view("/proxy/{id}/{path:_matrix/.+}") async def proxy(request: web.Request) -> web.StreamResponse: user_id = request.match_info.get("id", None) - client = await Client.get(user_id) + client = Client.get(user_id, None) if not client: return resp.client_not_found @@ -45,9 +45,8 @@ async def proxy(request: web.Request) -> web.StreamResponse: headers["X-Forwarded-For"] = f"{host}:{port}" data = await request.read() - async with http.request( - request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data - ) as proxy_resp: + async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers, + params=query, data=data) as proxy_resp: response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers) await response.prepare(request) async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE): diff --git a/maubot/management/api/dev_open.py b/maubot/management/api/dev_open.py index 2881d46..323c515 100644 --- a/maubot/management/api/dev_open.py +++ b/maubot/management/api/dev_open.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,11 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from string import Template -import asyncio +from subprocess import run import re -from aiohttp import web from ruamel.yaml import YAML +from aiohttp import web from .base import routes @@ -27,7 +27,9 @@ enabled = False @routes.get("/debug/open") async def check_enabled(_: web.Request) -> web.Response: - return web.json_response({"enabled": enabled}) + return web.json_response({ + "enabled": enabled, + }) try: @@ -38,6 +40,7 @@ try: editor_command = Template(cfg["editor"]) pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]] + @routes.post("/debug/open") async def open_file(request: web.Request) -> web.Response: data = await request.json() @@ -48,9 +51,13 @@ try: cmd = editor_command.substitute(path=path, line=data["line"]) except (KeyError, ValueError): return web.Response(status=400) - res = await asyncio.create_subprocess_shell(cmd) - stdout, stderr = await res.communicate() - return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr}) + res = run(cmd, shell=True) + return web.json_response({ + "return": res.returncode, + "stdout": res.stdout, + "stderr": res.stderr + }) + enabled = True except Exception: diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index 4043221..91861af 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,9 +17,10 @@ from json import JSONDecodeError from aiohttp import web -from ...client import Client +from ...db import DBPlugin from ...instance import PluginInstance from ...loader import PluginLoader +from ...client import Client from .base import routes from .responses import resp @@ -31,50 +32,51 @@ async def get_instances(_: web.Request) -> web.Response: @routes.get("/instance/{id}") async def get_instance(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id, None) if not instance: return resp.instance_not_found return resp.found(instance.to_dict()) async def _create_instance(instance_id: str, data: dict) -> web.Response: - plugin_type = data.get("type") - primary_user = data.get("primary_user") + plugin_type = data.get("type", None) + primary_user = data.get("primary_user", None) if not plugin_type: return resp.plugin_type_required elif not primary_user: return resp.primary_user_required - elif not await Client.get(primary_user): + elif not Client.get(primary_user): return resp.primary_user_not_found try: PluginLoader.find(plugin_type) except KeyError: return resp.plugin_type_not_found - instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user) - instance.enabled = data.get("enabled", True) - instance.config_str = data.get("config") or "" - await instance.update() - await instance.load() + db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True), + primary_user=primary_user, config=data.get("config", "")) + instance = PluginInstance(db_instance) + instance.load() + instance.db_instance.insert() await instance.start() return resp.created(instance.to_dict()) async def _update_instance(instance: PluginInstance, data: dict) -> web.Response: - if not await instance.update_primary_user(data.get("primary_user")): + if not await instance.update_primary_user(data.get("primary_user", None)): return resp.primary_user_not_found - await instance.update_id(data.get("id")) - await instance.update_enabled(data.get("enabled")) - await instance.update_config(data.get("config")) - await instance.update_started(data.get("started")) - await instance.update_type(data.get("type")) - return resp.updated(instance.to_dict()) + with instance.db_instance.edit_mode(): + instance.update_id(data.get("id", None)) + instance.update_enabled(data.get("enabled", None)) + instance.update_config(data.get("config", None)) + await instance.update_started(data.get("started", None)) + await instance.update_type(data.get("type", None)) + return resp.updated(instance.to_dict()) @routes.put("/instance/{id}") async def update_instance(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id, None) try: data = await request.json() except JSONDecodeError: @@ -87,11 +89,11 @@ async def update_instance(request: web.Request) -> web.Response: @routes.delete("/instance/{id}") async def delete_instance(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "").lower() + instance = PluginInstance.get(instance_id) if not instance: return resp.instance_not_found if instance.started: await instance.stop() - await instance.delete() + instance.delete() return resp.deleted diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py index 2f8c37a..bc3baf3 100644 --- a/maubot/management/api/instance_database.py +++ b/maubot/management/api/instance_database.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,67 +13,80 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - +from typing import Union, TYPE_CHECKING from datetime import datetime from aiohttp import web -from asyncpg import PostgresError -import aiosqlite - -from mautrix.util.async_db import Database +from sqlalchemy import Table, Column, asc, desc, exc +from sqlalchemy.orm import Query +from sqlalchemy.engine.result import ResultProxy, RowProxy from ...instance import PluginInstance -from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc from .base import routes from .responses import resp @routes.get("/instance/{id}/database") async def get_database(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "") + instance = PluginInstance.get(instance_id, None) if not instance: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - return web.json_response(await instance.get_db_tables()) + if TYPE_CHECKING: + table: Table + column: Column + return web.json_response({ + table.name: { + "columns": { + column.name: { + "type": str(column.type), + "unique": column.unique or False, + "default": column.default, + "nullable": column.nullable, + "primary": column.primary_key, + "autoincrement": column.autoincrement, + } for column in table.columns + }, + } for table in instance.get_db_tables().values() + }) + + +def check_type(val): + if isinstance(val, datetime): + return val.isoformat() + return val @routes.get("/instance/{id}/database/{table}") async def get_table(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "") + instance = PluginInstance.get(instance_id, None) if not instance: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - tables = await instance.get_db_tables() + tables = instance.get_db_tables() try: table = tables[request.match_info.get("table", "")] except KeyError: return resp.table_not_found try: order = [tuple(order.split(":")) for order in request.query.getall("order")] - order = [ - ( - (asc if sort.lower() == "asc" else desc)(table.columns[column]) - if sort - else table.columns[column] - ) - for column, sort in order - ] + order = [(asc if sort.lower() == "asc" else desc)(table.columns[column]) + if sort else table.columns[column] + for column, sort in order] except KeyError: order = [] - limit = int(request.query.get("limit", "100")) - if isinstance(instance.inst_db, Engine): - return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit)) + limit = int(request.query.get("limit", 100)) + return execute_query(instance, table.select().order_by(*order).limit(limit)) @routes.post("/instance/{id}/database/query") async def query(request: web.Request) -> web.Response: - instance_id = request.match_info["id"].lower() - instance = await PluginInstance.get(instance_id) + instance_id = request.match_info.get("id", "") + instance = PluginInstance.get(instance_id, None) if not instance: return resp.instance_not_found elif not instance.inst_db: @@ -83,76 +96,28 @@ async def query(request: web.Request) -> web.Response: sql_query = data["query"] except KeyError: return resp.query_missing - rows_as_dict = data.get("rows_as_dict", False) - if isinstance(instance.inst_db, Engine): - return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict) - elif isinstance(instance.inst_db, Database): - try: - return await _execute_query_asyncpg(instance, sql_query, rows_as_dict) - except (PostgresError, aiosqlite.Error) as e: - return resp.sql_error(e, sql_query) - else: - return resp.unsupported_plugin_database + return execute_query(instance, sql_query, + rows_as_dict=data.get("rows_as_dict", False)) -def check_type(val): - if isinstance(val, datetime): - return val.isoformat() - return val - - -async def _execute_query_asyncpg( - instance: PluginInstance, sql_query: str, rows_as_dict: bool = False -) -> web.Response: - data = {"ok": True, "query": sql_query} - if sql_query.upper().startswith("SELECT"): - res = await instance.inst_db.fetch(sql_query) - data["rows"] = [ - ( - {key: check_type(value) for key, value in row.items()} - if rows_as_dict - else [check_type(value) for value in row] - ) - for row in res - ] - if len(res) > 0: - # TODO can we find column names when there are no rows? - data["columns"] = list(res[0].keys()) - else: - res = await instance.inst_db.execute(sql_query) - if isinstance(res, str): - data["status_msg"] = res - elif isinstance(res, aiosqlite.Cursor): - data["rowcount"] = res.rowcount - # data["inserted_primary_key"] = res.lastrowid - else: - data["status_msg"] = "unknown status" - return web.json_response(data) - - -def _execute_query_sqlalchemy( - instance: PluginInstance, sql_query: str, rows_as_dict: bool = False -) -> web.Response: - assert isinstance(instance.inst_db, Engine) +def execute_query(instance: PluginInstance, sql_query: Union[str, Query], + rows_as_dict: bool = False) -> web.Response: try: - res = instance.inst_db.execute(sql_query) - except IntegrityError as e: + res: ResultProxy = instance.inst_db.execute(sql_query) + except exc.IntegrityError as e: return resp.sql_integrity_error(e, sql_query) - except OperationalError as e: + except exc.OperationalError as e: return resp.sql_operational_error(e, sql_query) data = { "ok": True, "query": str(sql_query), } if res.returns_rows: - data["rows"] = [ - ( - {key: check_type(value) for key, value in row.items()} - if rows_as_dict - else [check_type(value) for value in row] - ) - for row in res - ] + row: RowProxy + data["rows"] = [({key: check_type(value) for key, value in row.items()} + if rows_as_dict + else [check_type(value) for value in row]) + for row in res] data["columns"] = res.keys() else: data["rowcount"] = res.rowcount diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py index 14c80cd..d6ec092 100644 --- a/maubot/management/api/log.py +++ b/maubot/management/api/log.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,62 +13,31 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from collections import deque +from typing import Deque, List from datetime import datetime -import asyncio +from collections import deque import logging +import asyncio -from aiohttp import web, web_ws - -from mautrix.util import background_task +from aiohttp import web +from .base import routes, get_loop from .auth import is_valid_token -from .base import routes -BUILTIN_ATTRS = { - "args", - "asctime", - "created", - "exc_info", - "exc_text", - "filename", - "funcName", - "levelname", - "levelno", - "lineno", - "module", - "msecs", - "message", - "msg", - "name", - "pathname", - "process", - "processName", - "relativeCreated", - "stack_info", - "thread", - "threadName", -} -INCLUDE_ATTRS = { - "filename", - "funcName", - "levelname", - "levelno", - "lineno", - "module", - "name", - "pathname", -} +BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName", + "levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name", + "pathname", "process", "processName", "relativeCreated", "stack_info", "thread", + "threadName"} +INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name", + "pathname"} EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS MAX_LINES = 2048 class LogCollector(logging.Handler): - lines: deque[dict] + lines: Deque[dict] formatter: logging.Formatter - listeners: list[web.WebSocketResponse] + listeners: List[web.WebSocketResponse] loop: asyncio.AbstractEventLoop def __init__(self, level=logging.NOTSET) -> None: @@ -87,7 +56,9 @@ class LogCollector(logging.Handler): # JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license) # https://github.com/marselester/json-log-formatter content = { - name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS + name: value + for name, value in record.__dict__.items() + if name not in EXCLUDE_ATTRS } content["id"] = str(record.relativeCreated) content["msg"] = record.getMessage() @@ -111,18 +82,18 @@ class LogCollector(logging.Handler): handler = LogCollector() +log_root = logging.getLogger("maubot") log = logging.getLogger("maubot.server.websocket") sockets = [] def init(loop: asyncio.AbstractEventLoop) -> None: - logging.root.addHandler(handler) + log_root.addHandler(handler) handler.loop = loop async def stop_all() -> None: - log.debug("Closing log listener websockets") - logging.root.removeHandler(handler) + log_root.removeHandler(handler) for socket in sockets: try: await socket.close(code=1012) @@ -139,15 +110,14 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse: authenticated = False async def close_if_not_authenticated(): - await asyncio.sleep(5) + await asyncio.sleep(5, loop=get_loop()) if not authenticated: await ws.close(code=4000) log.debug(f"Connection from {request.remote} terminated due to no authentication") - background_task.create(close_if_not_authenticated()) + asyncio.ensure_future(close_if_not_authenticated()) try: - msg: web_ws.WSMessage async for msg in ws: if msg.type != web.WSMsgType.TEXT: continue diff --git a/maubot/management/api/login.py b/maubot/management/api/login.py index bfb2f6a..21f9342 100644 --- a/maubot/management/api/login.py +++ b/maubot/management/api/login.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,10 +17,9 @@ import json from aiohttp import web -from .auth import create_token -from .base import get_config, routes +from .base import routes, get_config from .responses import resp - +from .auth import create_token @routes.post("/auth/login") async def login(request: web.Request) -> web.Response: diff --git a/maubot/management/api/middleware.py b/maubot/management/api/middleware.py index 17141fa..ff6b4c1 100644 --- a/maubot/management/api/middleware.py +++ b/maubot/management/api/middleware.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,15 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Awaitable, Callable -import base64 +from typing import Callable, Awaitable import logging from aiohttp import web +from .responses import resp from .auth import check_token from .base import get_config -from .responses import resp Handler = Callable[[web.Request], Awaitable[web.Response]] log = logging.getLogger("maubot.server") @@ -29,13 +28,8 @@ log = logging.getLogger("maubot.server") @web.middleware async def auth(request: web.Request, handler: Handler) -> web.Response: - subpath = request.path[len("/_matrix/maubot/v1") :] - if ( - subpath.startswith("/auth/") - or subpath.startswith("/client/auth_external_sso/complete/") - or subpath == "/features" - or subpath == "/logs" - ): + subpath = request.path[len(get_config()["server.base_path"]):] + if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs": return await handler(request) err = check_token(request) if err is not None: @@ -52,18 +46,10 @@ async def error(request: web.Request, handler: Handler) -> web.Response: return resp.path_not_found elif ex.status_code == 405: return resp.method_not_allowed - return web.json_response( - { - "httpexception": { - "headers": {key: value for key, value in ex.headers.items()}, - "class": type(ex).__name__, - "body": ex.text or base64.b64encode(ex.body), - }, - "error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}", - "errcode": f"unhandled_http_{ex.status}", - }, - status=ex.status, - ) + return web.json_response({ + "error": f"Unhandled HTTP {ex.status}", + "errcode": f"unhandled_http_{ex.status}", + }, status=ex.status) except Exception: log.exception("Error in handler") return resp.internal_server_error diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index 94d8d9d..4429e11 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,9 +17,9 @@ import traceback from aiohttp import web -from ...loader import MaubotZipImportError, PluginLoader -from .base import routes +from ...loader import PluginLoader, MaubotZipImportError from .responses import resp +from .base import routes @routes.get("/plugins") @@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response: @routes.get("/plugin/{id}") async def get_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info["id"] - plugin = PluginLoader.id_cache.get(plugin_id) + plugin_id = request.match_info.get("id", None) + plugin = PluginLoader.id_cache.get(plugin_id, None) if not plugin: return resp.plugin_not_found return resp.found(plugin.to_dict()) @@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response: @routes.delete("/plugin/{id}") async def delete_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info["id"] - plugin = PluginLoader.id_cache.get(plugin_id) + plugin_id = request.match_info.get("id", None) + plugin = PluginLoader.id_cache.get(plugin_id, None) if not plugin: return resp.plugin_not_found elif len(plugin.references) > 0: @@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response: @routes.post("/plugin/{id}/reload") async def reload_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info["id"] - plugin = PluginLoader.id_cache.get(plugin_id) + plugin_id = request.match_info.get("id", None) + plugin = PluginLoader.id_cache.get(plugin_id, None) if not plugin: return resp.plugin_not_found diff --git a/maubot/management/api/plugin_upload.py b/maubot/management/api/plugin_upload.py index 4cd2c47..7b5b5de 100644 --- a/maubot/management/api/plugin_upload.py +++ b/maubot/management/api/plugin_upload.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -15,39 +15,27 @@ # along with this program. If not, see . from io import BytesIO from time import time -import logging +import traceback import os.path import re -import traceback from aiohttp import web from packaging.version import Version -from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader -from .base import get_config, routes +from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError from .responses import resp - -try: - import sqlalchemy - - has_alchemy = True -except ImportError: - has_alchemy = False - -log = logging.getLogger("maubot.server.upload") +from .base import routes, get_config @routes.put("/plugin/{id}") async def put_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info["id"] + plugin_id = request.match_info.get("id", None) content = await request.read() file = BytesIO(content) try: - pid, version, db_type = ZippedPluginLoader.verify_meta(file) + pid, version = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) - if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: - return resp.sqlalchemy_not_installed if pid != plugin_id: return resp.pid_mismatch plugin = PluginLoader.id_cache.get(plugin_id, None) @@ -64,11 +52,9 @@ async def upload_plugin(request: web.Request) -> web.Response: content = await request.read() file = BytesIO(content) try: - pid, version, db_type = ZippedPluginLoader.verify_meta(file) + pid, version = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) - if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: - return resp.sqlalchemy_not_installed plugin = PluginLoader.id_cache.get(pid, None) if not plugin: return await upload_new_plugin(content, pid, version) @@ -92,20 +78,15 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R return resp.created(plugin.to_dict()) -async def upload_replacement_plugin( - plugin: ZippedPluginLoader, content: bytes, new_version: Version -) -> web.Response: +async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, + new_version: Version) -> web.Response: dirname = os.path.dirname(plugin.path) old_filename = os.path.basename(plugin.path) if str(plugin.meta.version) in old_filename: - replacement = ( - str(new_version) - if plugin.meta.version != new_version - else f"{new_version}-ts{int(time() * 1000)}" - ) - filename = re.sub( - f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename - ) + replacement = (str(new_version) if plugin.meta.version != new_version + else f"{new_version}-ts{int(time())}") + filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", + replacement, old_filename) else: filename = old_filename.rstrip(".mbp") filename = f"{filename}-v{new_version}.mbp" @@ -117,29 +98,12 @@ async def upload_replacement_plugin( try: await plugin.reload(new_path=path) except MaubotZipImportError as e: - log.exception(f"Error loading updated version of {plugin.meta.id}, rolling back") try: await plugin.reload(new_path=old_path) await plugin.start_instances() except MaubotZipImportError: - log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True) - finally: - ZippedPluginLoader.trash(path, reason="failed_update") + pass return resp.plugin_import_error(str(e), traceback.format_exc()) - try: - await plugin.start_instances() - except Exception as e: - log.exception(f"Error starting {plugin.meta.id} instances after update, rolling back") - try: - await plugin.stop_instances() - await plugin.reload(new_path=old_path) - await plugin.start_instances() - except Exception: - log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True) - finally: - ZippedPluginLoader.trash(path, reason="failed_update") - return resp.plugin_reload_error(str(e), traceback.format_exc()) - - log.debug(f"Successfully updated {plugin.meta.id}, moving old version to trash") + await plugin.start_instances() ZippedPluginLoader.trash(old_path, reason="update") return resp.updated(plugin.to_dict()) diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index 0f22caa..b30d49d 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2019 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,457 +13,271 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from __future__ import annotations - -from typing import TYPE_CHECKING from http import HTTPStatus from aiohttp import web -from asyncpg import PostgresError -import aiosqlite - -if TYPE_CHECKING: - from sqlalchemy.exc import IntegrityError, OperationalError +from sqlalchemy.exc import OperationalError, IntegrityError class _Response: @property def body_not_json(self) -> web.Response: - return web.json_response( - { - "error": "Request body is not JSON", - "errcode": "body_not_json", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Request body is not JSON", + "errcode": "body_not_json", + }, status=HTTPStatus.BAD_REQUEST) @property def plugin_type_required(self) -> web.Response: - return web.json_response( - { - "error": "Plugin type is required when creating plugin instances", - "errcode": "plugin_type_required", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Plugin type is required when creating plugin instances", + "errcode": "plugin_type_required", + }, status=HTTPStatus.BAD_REQUEST) @property def primary_user_required(self) -> web.Response: - return web.json_response( - { - "error": "Primary user is required when creating plugin instances", - "errcode": "primary_user_required", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Primary user is required when creating plugin instances", + "errcode": "primary_user_required", + }, status=HTTPStatus.BAD_REQUEST) @property def bad_client_access_token(self) -> web.Response: - return web.json_response( - { - "error": "Invalid access token", - "errcode": "bad_client_access_token", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Invalid access token", + "errcode": "bad_client_access_token", + }, status=HTTPStatus.BAD_REQUEST) @property def bad_client_access_details(self) -> web.Response: - return web.json_response( - { - "error": "Invalid homeserver or access token", - "errcode": "bad_client_access_details", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Invalid homeserver or access token", + "errcode": "bad_client_access_details" + }, status=HTTPStatus.BAD_REQUEST) @property def bad_client_connection_details(self) -> web.Response: - return web.json_response( - { - "error": "Could not connect to homeserver", - "errcode": "bad_client_connection_details", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Could not connect to homeserver", + "errcode": "bad_client_connection_details" + }, status=HTTPStatus.BAD_REQUEST) def mxid_mismatch(self, found: str) -> web.Response: - return web.json_response( - { - "error": ( - "The Matrix user ID of the client and the user ID of the access token don't " - f"match. Access token is for user {found}" - ), - "errcode": "mxid_mismatch", - }, - status=HTTPStatus.BAD_REQUEST, - ) - - def device_id_mismatch(self, found: str) -> web.Response: - return web.json_response( - { - "error": ( - "The Matrix device ID of the client and the device ID of the access token " - f"don't match. Access token is for device {found}" - ), - "errcode": "mxid_mismatch", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "The Matrix user ID of the client and the user ID of the access token don't " + f"match. Access token is for user {found}", + "errcode": "mxid_mismatch", + }, status=HTTPStatus.BAD_REQUEST) @property def pid_mismatch(self) -> web.Response: - return web.json_response( - { - "error": "The ID in the path does not match the ID of the uploaded plugin", - "errcode": "pid_mismatch", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "The ID in the path does not match the ID of the uploaded plugin", + "errcode": "pid_mismatch", + }, status=HTTPStatus.BAD_REQUEST) @property def username_or_password_missing(self) -> web.Response: - return web.json_response( - { - "error": "Username or password missing", - "errcode": "username_or_password_missing", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Username or password missing", + "errcode": "username_or_password_missing", + }, status=HTTPStatus.BAD_REQUEST) @property def query_missing(self) -> web.Response: - return web.json_response( - { - "error": "Query missing", - "errcode": "query_missing", - }, - status=HTTPStatus.BAD_REQUEST, - ) - - @staticmethod - def sql_error(error: PostgresError | aiosqlite.Error, query: str) -> web.Response: - return web.json_response( - { - "ok": False, - "query": query, - "error": str(error), - "errcode": "sql_error", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Query missing", + "errcode": "query_missing", + }, status=HTTPStatus.BAD_REQUEST) @staticmethod def sql_operational_error(error: OperationalError, query: str) -> web.Response: - return web.json_response( - { - "ok": False, - "query": query, - "error": str(error.orig), - "full_error": str(error), - "errcode": "sql_operational_error", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "ok": False, + "query": query, + "error": str(error.orig), + "full_error": str(error), + "errcode": "sql_operational_error", + }, status=HTTPStatus.BAD_REQUEST) @staticmethod def sql_integrity_error(error: IntegrityError, query: str) -> web.Response: - return web.json_response( - { - "ok": False, - "query": query, - "error": str(error.orig), - "full_error": str(error), - "errcode": "sql_integrity_error", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "ok": False, + "query": query, + "error": str(error.orig), + "full_error": str(error), + "errcode": "sql_integrity_error", + }, status=HTTPStatus.BAD_REQUEST) @property def bad_auth(self) -> web.Response: - return web.json_response( - { - "error": "Invalid username or password", - "errcode": "invalid_auth", - }, - status=HTTPStatus.UNAUTHORIZED, - ) + return web.json_response({ + "error": "Invalid username or password", + "errcode": "invalid_auth", + }, status=HTTPStatus.UNAUTHORIZED) @property def no_token(self) -> web.Response: - return web.json_response( - { - "error": "Authorization token missing", - "errcode": "auth_token_missing", - }, - status=HTTPStatus.UNAUTHORIZED, - ) + return web.json_response({ + "error": "Authorization token missing", + "errcode": "auth_token_missing", + }, status=HTTPStatus.UNAUTHORIZED) @property def invalid_token(self) -> web.Response: - return web.json_response( - { - "error": "Invalid authorization token", - "errcode": "auth_token_invalid", - }, - status=HTTPStatus.UNAUTHORIZED, - ) + return web.json_response({ + "error": "Invalid authorization token", + "errcode": "auth_token_invalid", + }, status=HTTPStatus.UNAUTHORIZED) @property def plugin_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Plugin not found", - "errcode": "plugin_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Plugin not found", + "errcode": "plugin_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def client_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Client not found", - "errcode": "client_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Client not found", + "errcode": "client_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def primary_user_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Client for given primary user not found", - "errcode": "primary_user_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Client for given primary user not found", + "errcode": "primary_user_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def instance_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Plugin instance not found", - "errcode": "instance_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Plugin instance not found", + "errcode": "instance_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def plugin_type_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Given plugin type not found", - "errcode": "plugin_type_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Given plugin type not found", + "errcode": "plugin_type_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def path_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Resource not found", - "errcode": "resource_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) + return web.json_response({ + "error": "Resource not found", + "errcode": "resource_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def server_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Registration target server not found", - "errcode": "server_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) - - @property - def registration_secret_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Config does not have a registration secret for that server", - "errcode": "registration_secret_not_found", - }, - status=HTTPStatus.NOT_FOUND, - ) - - @property - def registration_no_sso(self) -> web.Response: - return web.json_response( - { - "error": "The register operation is only for registering with a password", - "errcode": "registration_no_sso", - }, - status=HTTPStatus.BAD_REQUEST, - ) - - @property - def sso_not_supported(self) -> web.Response: - return web.json_response( - { - "error": "That server does not seem to support single sign-on", - "errcode": "sso_not_supported", - }, - status=HTTPStatus.FORBIDDEN, - ) + return web.json_response({ + "error": "Registration target server not found", + "errcode": "server_not_found", + }, status=HTTPStatus.NOT_FOUND) @property def plugin_has_no_database(self) -> web.Response: - return web.json_response( - { - "error": "Given plugin does not have a database", - "errcode": "plugin_has_no_database", - } - ) - - @property - def unsupported_plugin_database(self) -> web.Response: - return web.json_response( - { - "error": "The database type is not supported by this API", - "errcode": "unsupported_plugin_database", - } - ) - - @property - def sqlalchemy_not_installed(self) -> web.Response: - return web.json_response( - { - "error": "This plugin requires a legacy database, but SQLAlchemy is not installed", - "errcode": "unsupported_plugin_database", - }, - status=HTTPStatus.NOT_IMPLEMENTED, - ) + return web.json_response({ + "error": "Given plugin does not have a database", + "errcode": "plugin_has_no_database", + }) @property def table_not_found(self) -> web.Response: - return web.json_response( - { - "error": "Given table not found in plugin database", - "errcode": "table_not_found", - } - ) + return web.json_response({ + "error": "Given table not found in plugin database", + "errcode": "table_not_found", + }) @property def method_not_allowed(self) -> web.Response: - return web.json_response( - { - "error": "Method not allowed", - "errcode": "method_not_allowed", - }, - status=HTTPStatus.METHOD_NOT_ALLOWED, - ) + return web.json_response({ + "error": "Method not allowed", + "errcode": "method_not_allowed", + }, status=HTTPStatus.METHOD_NOT_ALLOWED) @property def user_exists(self) -> web.Response: - return web.json_response( - { - "error": "There is already a client with the user ID of that token", - "errcode": "user_exists", - }, - status=HTTPStatus.CONFLICT, - ) + return web.json_response({ + "error": "There is already a client with the user ID of that token", + "errcode": "user_exists", + }, status=HTTPStatus.CONFLICT) @property def plugin_exists(self) -> web.Response: - return web.json_response( - { - "error": "A plugin with the same ID as the uploaded plugin already exists", - "errcode": "plugin_exists", - }, - status=HTTPStatus.CONFLICT, - ) + return web.json_response({ + "error": "A plugin with the same ID as the uploaded plugin already exists", + "errcode": "plugin_exists" + }, status=HTTPStatus.CONFLICT) @property def plugin_in_use(self) -> web.Response: - return web.json_response( - { - "error": "Plugin instances of this type still exist", - "errcode": "plugin_in_use", - }, - status=HTTPStatus.PRECONDITION_FAILED, - ) + return web.json_response({ + "error": "Plugin instances of this type still exist", + "errcode": "plugin_in_use", + }, status=HTTPStatus.PRECONDITION_FAILED) @property def client_in_use(self) -> web.Response: - return web.json_response( - { - "error": "Plugin instances with this client as their primary user still exist", - "errcode": "client_in_use", - }, - status=HTTPStatus.PRECONDITION_FAILED, - ) + return web.json_response({ + "error": "Plugin instances with this client as their primary user still exist", + "errcode": "client_in_use", + }, status=HTTPStatus.PRECONDITION_FAILED) @staticmethod def plugin_import_error(error: str, stacktrace: str) -> web.Response: - return web.json_response( - { - "error": error, - "stacktrace": stacktrace, - "errcode": "plugin_invalid", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": error, + "stacktrace": stacktrace, + "errcode": "plugin_invalid", + }, status=HTTPStatus.BAD_REQUEST) @staticmethod def plugin_reload_error(error: str, stacktrace: str) -> web.Response: - return web.json_response( - { - "error": error, - "stacktrace": stacktrace, - "errcode": "plugin_reload_fail", - }, - status=HTTPStatus.INTERNAL_SERVER_ERROR, - ) + return web.json_response({ + "error": error, + "stacktrace": stacktrace, + "errcode": "plugin_reload_fail", + }, status=HTTPStatus.INTERNAL_SERVER_ERROR) @property def internal_server_error(self) -> web.Response: - return web.json_response( - { - "error": "Internal server error", - "errcode": "internal_server_error", - }, - status=HTTPStatus.INTERNAL_SERVER_ERROR, - ) + return web.json_response({ + "error": "Internal server error", + "errcode": "internal_server_error", + }, status=HTTPStatus.INTERNAL_SERVER_ERROR) @property def invalid_server(self) -> web.Response: - return web.json_response( - { - "error": "Invalid registration server object in maubot configuration", - "errcode": "invalid_server", - }, - status=HTTPStatus.INTERNAL_SERVER_ERROR, - ) + return web.json_response({ + "error": "Invalid registration server object in maubot configuration", + "errcode": "invalid_server", + }, status=HTTPStatus.INTERNAL_SERVER_ERROR) @property def unsupported_plugin_loader(self) -> web.Response: - return web.json_response( - { - "error": "Existing plugin with same ID uses unsupported plugin loader", - "errcode": "unsupported_plugin_loader", - }, - status=HTTPStatus.BAD_REQUEST, - ) + return web.json_response({ + "error": "Existing plugin with same ID uses unsupported plugin loader", + "errcode": "unsupported_plugin_loader", + }, status=HTTPStatus.BAD_REQUEST) @property def not_implemented(self) -> web.Response: - return web.json_response( - { - "error": "Not implemented", - "errcode": "not_implemented", - }, - status=HTTPStatus.NOT_IMPLEMENTED, - ) + return web.json_response({ + "error": "Not implemented", + "errcode": "not_implemented", + }, status=HTTPStatus.NOT_IMPLEMENTED) @property def ok(self) -> web.Response: - return web.json_response( - {"success": True}, - status=HTTPStatus.OK, - ) + return web.json_response({ + "success": True, + }, status=HTTPStatus.OK) @property def deleted(self) -> web.Response: @@ -473,15 +287,19 @@ class _Response: def found(data: dict) -> web.Response: return web.json_response(data, status=HTTPStatus.OK) - @staticmethod - def updated(data: dict, is_login: bool = False) -> web.Response: - return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK) + def updated(self, data: dict) -> web.Response: + return self.found(data) def logged_in(self, token: str) -> web.Response: - return self.found({"token": token}) + return self.found({ + "token": token, + }) def pong(self, user: str, features: dict) -> web.Response: - return self.found({"username": user, "features": features}) + return self.found({ + "username": user, + "features": features, + }) @staticmethod def created(data: dict) -> web.Response: diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml index 8529599..c6f5181 100644 --- a/maubot/management/api/spec.yaml +++ b/maubot/management/api/spec.yaml @@ -366,7 +366,7 @@ paths: schema: $ref: '#/components/schemas/MatrixClient' responses: - 202: + 200: description: Client updated content: application/json: @@ -454,12 +454,6 @@ paths: required: true schema: type: string - - name: update_client - in: query - description: Should maubot store the access details in a Client instead of returning them? - required: false - schema: - type: boolean post: operationId: client_auth_register summary: | @@ -481,29 +475,18 @@ paths: properties: access_token: type: string - example: syt_123_456_789 + example: token_here user_id: type: string example: '@putkiteippi:maunium.net' + home_server: + type: string + example: maunium.net device_id: type: string - example: maubot_F00BAR12 - 201: - description: Client created (when update_client is true) - content: - application/json: - schema: - $ref: '#/components/schemas/MatrixClient' + example: device_id_here 401: $ref: '#/components/responses/Unauthorized' - 409: - description: | - There is already a client with the user ID of that token. - This should usually not happen, because the user ID was just created. - content: - application/json: - schema: - $ref: '#/components/schemas/Error' 500: $ref: '#/components/responses/MatrixServerError' '/client/auth/{server}/login': @@ -514,12 +497,6 @@ paths: required: true schema: type: string - - name: update_client - in: query - description: Should maubot store the access details in a Client instead of returning them? - required: false - schema: - type: boolean post: operationId: client_auth_login summary: Log in to the given Matrix server via the maubot server @@ -542,22 +519,10 @@ paths: example: '@putkiteippi:maunium.net' access_token: type: string - example: syt_123_456_789 + example: token_here device_id: type: string - example: maubot_F00BAR12 - 201: - description: Client created (when update_client is true) - content: - application/json: - schema: - $ref: '#/components/schemas/MatrixClient' - 202: - description: Client updated (when update_client is true) - content: - application/json: - schema: - $ref: '#/components/schemas/MatrixClient' + example: device_id_here 401: $ref: '#/components/responses/Unauthorized' 500: @@ -676,12 +641,6 @@ components: access_token: type: string description: The Matrix access token for this client. - device_id: - type: string - description: The Matrix device ID corresponding to the access token. - fingerprint: - type: string - description: The encryption device fingerprint for verification. enabled: type: boolean example: true diff --git a/maubot/management/frontend/package.json b/maubot/management/frontend/package.json index 294c1f7..c607123 100644 --- a/maubot/management/frontend/package.json +++ b/maubot/management/frontend/package.json @@ -13,15 +13,15 @@ }, "homepage": ".", "dependencies": { + "sass": "^1.34.1", "react": "^17.0.2", "react-ace": "^9.4.1", "react-contextmenu": "^2.14.0", "react-dom": "^17.0.2", - "react-json-tree": "^0.16.1", - "react-router-dom": "^5.3.0", - "react-scripts": "5.0.0", - "react-select": "^5.2.1", - "sass": "^1.34.1" + "react-json-tree": "^0.15.0", + "react-router-dom": "^5.2.0", + "react-scripts": "4.0.3", + "react-select": "^4.3.1" }, "scripts": { "start": "react-scripts start", diff --git a/maubot/management/frontend/public/index.html b/maubot/management/frontend/public/index.html index d3679bf..43255d8 100644 --- a/maubot/management/frontend/public/index.html +++ b/maubot/management/frontend/public/index.html @@ -1,6 +1,6 @@