diff --git a/.editorconfig b/.editorconfig index f1ad5ca..3d6370d 100644 --- a/.editorconfig +++ b/.editorconfig @@ -14,3 +14,6 @@ indent_size = 2 [spec.yaml] indent_size = 2 + +[CHANGELOG.md] +max_line_length = 80 diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml deleted file mode 100644 index 44f4c91..0000000 --- a/.github/FUNDING.yml +++ /dev/null @@ -1 +0,0 @@ -github: tulir diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml new file mode 100644 index 0000000..28d6df2 --- /dev/null +++ b/.github/workflows/python-lint.yml @@ -0,0 +1,26 @@ +name: Python lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.13" + - uses: isort/isort-action@master + with: + sortPaths: "./maubot" + - uses: psf/black@stable + with: + src: "./maubot" + version: "24.10.0" + - name: pre-commit + run: | + pip install pre-commit + pre-commit run -av trailing-whitespace + pre-commit run -av end-of-file-fixer + pre-commit run -av check-yaml + pre-commit run -av check-added-large-files diff --git a/.gitignore b/.gitignore index d475bc1..9fd28ef 100644 --- a/.gitignore +++ b/.gitignore @@ -7,10 +7,13 @@ pip-selfcheck.json *.pyc __pycache__ -*.db +*.db* +*.log /*.yaml !example-config.yaml +!.pre-commit-config.yaml +/start logs/ plugins/ trash/ diff --git a/.gitlab-ci-plugin.yml b/.gitlab-ci-plugin.yml new file mode 100644 index 0000000..45ef06b --- /dev/null +++ b/.gitlab-ci-plugin.yml @@ -0,0 +1,29 @@ +image: dock.mau.dev/maubot/maubot + +stages: +- build + +variables: + PYTHONPATH: /opt/maubot + +build: + stage: build + except: + - tags + script: + - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_REF_NAME-$CI_COMMIT_SHORT_SHA.mbp + artifacts: + paths: + - "*.mbp" + expire_in: 365 days + +build tags: + stage: build + only: + - tags + script: + - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_TAG.mbp + artifacts: + paths: + - "*.mbp" + expire_in: never diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 4464992..50d0c15 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -10,7 +10,7 @@ default: - docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY build frontend: - image: node:lts-alpine + image: node:22-alpine stage: build frontend before_script: [] variables: @@ -58,8 +58,8 @@ manifest: script: - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 - docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 - - if [ $CI_COMMIT_BRANCH == "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi - - if [ $CI_COMMIT_BRANCH != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi + - if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:latest $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:latest; fi + - if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker manifest create $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 && docker manifest push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME; fi - docker rmi $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-amd64 $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-arm64 @@ -70,5 +70,6 @@ build standalone amd64: script: - docker pull $CI_REGISTRY_IMAGE:standalone || true - docker build --pull --cache-from $CI_REGISTRY_IMAGE:standalone --tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone . -f maubot/standalone/Dockerfile - - if [ $CI_COMMIT_BRANCH == "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi - - if [ $CI_COMMIT_BRANCH != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi + - if [ "$CI_COMMIT_BRANCH" = "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:standalone && docker push $CI_REGISTRY_IMAGE:standalone; fi + - if [ "$CI_COMMIT_BRANCH" != "master" ]; then docker tag $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone && docker push $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone; fi + - docker rmi $CI_REGISTRY_IMAGE:standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_NAME-standalone $CI_REGISTRY_IMAGE:$CI_COMMIT_SHA-standalone || true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4a6328e --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude_types: [markdown] + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + language_version: python3 + files: ^maubot/.*\.pyi?$ + - repo: https://github.com/PyCQA/isort + rev: 5.13.2 + hooks: + - id: isort + files: ^maubot/.*\.pyi?$ diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..d9de2b7 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,164 @@ +# v0.5.2 (2025-05-05) + +* Improved tombstone handling to ensure that the tombstone sender has + permissions to invite users to the target room. +* Fixed autojoin and online flags not being applied if set during client + creation (thanks to [@bnsh] in [#258]). +* Fixed plugin web apps not being cleared properly when unloading plugins. + +[@bnsh]: https://github.com/bnsh +[#258]: https://github.com/maubot/maubot/pull/258 + +# v0.5.1 (2025-01-03) + +* Updated Docker image to Alpine 3.21. +* Updated media upload/download endpoints in management frontend + (thanks to [@domrim] in [#253]). +* Fixed plugin web app base path not including a trailing slash + (thanks to [@jkhsjdhjs] in [#240]). +* Changed markdown parsing to cut off plaintext body if necessary to allow + longer formatted messages. +* Updated dependencies to fix Python 3.13 compatibility. + +[@domrim]: https://github.com/domrim +[@jkhsjdhjs]: https://github.com/jkhsjdhjs +[#253]: https://github.com/maubot/maubot/pull/253 +[#240]: https://github.com/maubot/maubot/pull/240 + +# v0.5.0 (2024-08-24) + +* Dropped Python 3.9 support. +* Updated Docker image to Alpine 3.20. +* Updated mautrix-python to 0.20.6 to support authenticated media. +* Removed hard dependency on SQLAlchemy. +* Fixed `main_class` to default to being loaded from the last module instead of + the first if a module name is not explicitly specified. + * This was already the [documented behavior](https://docs.mau.fi/maubot/dev/reference/plugin-metadata.html), + and loading from the first module doesn't make sense due to import order. +* Added simple scheduler utility for running background tasks periodically or + after a certain delay. +* Added testing framework for plugins (thanks to [@abompard] in [#225]). +* Changed `mbc build` to ignore directories declared in `modules` that are + missing an `__init__.py` file. + * Importing the modules at runtime would fail and break the plugin. + To include non-code resources outside modules in the mbp archive, + use `extra_files` instead. + +[#225]: https://github.com/maubot/maubot/issues/225 +[@abompard]: https://github.com/abompard + +# v0.4.2 (2023-09-20) + +* Updated Pillow to 10.0.1. +* Updated Docker image to Alpine 3.18. +* Added logging for errors for /whoami errors when adding new bot accounts. +* Added support for using appservice tokens (including appservice encryption) + in standalone mode. + +# v0.4.1 (2023-03-15) + +* Added `in_thread` parameter to `evt.reply()` and `evt.respond()`. + * By default, responses will go to the thread if the command is in a thread. + * By setting the flag to `True` or `False`, the plugin can force the response + to either be or not be in a thread. +* Fixed static files like the frontend app manifest not being served correctly. +* Fixed `self.loader.meta` not being available to plugins in standalone mode. +* Updated to mautrix-python v0.19.6. + +# v0.4.0 (2023-01-29) + +* Dropped support for using a custom maubot API base path. + * The public URL can still have a path prefix, e.g. when using a reverse + proxy. Both the web interface and `mbc` CLI tool should work fine with + custom prefixes. +* Added `evt.redact()` as a shortcut for `self.client.redact(evt.room_id, evt.event_id)`. +* Fixed `mbc logs` command not working on Python 3.8+. +* Fixed saving plugin configs (broke in v0.3.0). +* Fixed SSO login using the wrong API path (probably broke in v0.3.0). +* Stopped using `cd` in the docker image's `mbc` wrapper to enable using + path-dependent commands like `mbc build` by mounting a directory. +* Updated Docker image to Alpine 3.17. + +# v0.3.1 (2022-03-29) + +* Added encryption dependencies to standalone dockerfile. +* Fixed running without encryption dependencies installed. +* Removed unnecessary imports that broke on SQLAlchemy 1.4+. +* Removed unused alembic dependency. + +# v0.3.0 (2022-03-28) + +* Dropped Python 3.7 support. +* Switched main maubot database to asyncpg/aiosqlite. + * Using the same SQLite database for crypto is now safe again. +* Added support for asyncpg/aiosqlite for plugin databases. + * There are some [basic docs](https://docs.mau.fi/maubot/dev/database/index.html) + and [a simple example](./examples/database) for the new system. + * The old SQLAlchemy system is now deprecated, but will be preserved for + backwards-compatibility until most plugins have updated. +* Started enforcing minimum maubot version in plugins. + * Trying to upload a plugin where the specified version is higher than the + running maubot version will fail. +* Fixed bug where uploading a plugin twice, deleting it and trying to upload + again would fail. +* Updated Docker image to Alpine 3.15. +* Formatted all code using [black](https://github.com/psf/black) + and [isort](https://github.com/PyCQA/isort). + +# v0.2.1 (2021-11-22) + +Docker-only release: added automatic moving of plugin databases from +`/data/plugins/*.db` to `/data/dbs` + +# v0.2.0 (2021-11-20) + +* Moved plugin databases from `/data/plugins` to `/data/dbs` in the docker image. + * v0.2.0 was missing the automatic migration of databases, it was added in v0.2.1. + * If you were using a custom path, you'll have to mount it at `/data/dbs` or + move the databases yourself. +* Removed support for pickle crypto store and added support for SQLite crypto store. + * **If you were previously using the dangerous pickle store for e2ee, you'll + have to re-login with the bots (which can now be done conveniently with + `mbc auth --update-client`).** +* Added SSO support to `mbc auth`. +* Added support for setting device ID for e2ee using the web interface. +* Added e2ee fingerprint field to the web interface. +* Added `--update-client` flag to store access token inside maubot instead of + returning it in `mbc auth`. + * This will also automatically store the device ID now. +* Updated standalone mode. + * Added e2ee and web server support. + * It's now officially supported and [somewhat documented](https://docs.mau.fi/maubot/usage/standalone.html). +* Replaced `_` with `-` when generating command name from function name. +* Replaced unmaintained PyInquirer dependency with questionary + (thanks to [@TinfoilSubmarine] in [#139]). +* Updated Docker image to Alpine 3.14. +* Fixed avatar URLs without the `mxc://` prefix appearing like they work in the + frontend, but not actually working when saved. + +[@TinfoilSubmarine]: https://github.com/TinfoilSubmarine +[#139]: https://github.com/maubot/maubot/pull/139 + +# v0.1.2 (2021-06-12) + +* Added `loader` instance property for plugins to allow reading files within + the plugin archive. +* Added support for reloading `webapp` and `database` meta flags in plugins. + Previously you had to restart maubot instead of just reloading the plugin + when enabling the webapp or database for the first time. +* Added warning log if a plugin uses `@web` decorators without enabling the + `webapp` meta flag. +* Updated frontend to latest React and dependency versions. +* Updated Docker image to Alpine 3.13. +* Fixed registering accounts with Synapse shared secret registration. +* Fixed plugins using `get_event` in encrypted rooms. +* Fixed using the `@command.new` decorator without specifying a name + (i.e. falling back to the function name). + +# v0.1.1 (2021-05-02) + +No changelog. + +# v0.1.0 (2020-10-04) + +Initial tagged release. diff --git a/Dockerfile b/Dockerfile index f855dfb..2c6bad4 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,41 +1,34 @@ -FROM node:12 AS frontend-builder +FROM node:22 AS frontend-builder COPY ./maubot/management/frontend /frontend RUN cd /frontend && yarn --prod && yarn build -FROM alpine:3.12 - -RUN echo $'\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/main\n\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/testing\n\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/community' >> /etc/apk/repositories +FROM alpine:3.21 RUN apk add --no-cache \ python3 py3-pip py3-setuptools py3-wheel \ ca-certificates \ su-exec \ + yq \ py3-aiohttp \ - py3-sqlalchemy \ py3-attrs \ py3-bcrypt \ py3-cffi \ - py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ py3-packaging \ py3-markdown \ - py3-alembic@edge \ - py3-cssselect@edge \ - py3-commonmark@edge \ + py3-alembic \ + py3-cssselect \ + py3-commonmark \ py3-pygments \ - py3-tz@edge \ - py3-tzlocal@edge \ - py3-regex@edge \ - py3-wcwidth@edge \ + py3-tz \ + py3-regex \ + py3-wcwidth \ # encryption py3-cffi \ - olm-dev \ + py3-olm \ py3-pycryptodome \ py3-unpaddedbase64 \ py3-future \ @@ -45,21 +38,20 @@ RUN apk add --no-cache \ py3-feedparser \ py3-dateutil \ py3-lxml \ - py3-gitlab@edge \ - py3-semver@edge + py3-semver # TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ - && pip3 install -r requirements.txt -r optional-requirements.txt \ - dateparser langdetect python-gitlab pyquery cchardet \ + && pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \ + dateparser langdetect python-gitlab pyquery tzlocal \ && apk del .build-deps # TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies COPY . /opt/maubot +RUN cp maubot/example-config.yaml . COPY ./docker/mbc.sh /usr/local/bin/mbc COPY --from=frontend-builder /frontend/build /opt/maubot/frontend ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data diff --git a/Dockerfile.ci b/Dockerfile.ci index 494b7f5..9712a16 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -1,36 +1,30 @@ -FROM alpine:3.12 - -RUN echo $'\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/main\n\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/testing\n\ -@edge http://dl-cdn.alpinelinux.org/alpine/edge/community' >> /etc/apk/repositories +FROM alpine:3.21 RUN apk add --no-cache \ python3 py3-pip py3-setuptools py3-wheel \ ca-certificates \ su-exec \ + yq \ py3-aiohttp \ - py3-sqlalchemy \ py3-attrs \ py3-bcrypt \ py3-cffi \ - py3-psycopg2 \ py3-ruamel.yaml \ py3-jinja2 \ py3-click \ py3-packaging \ py3-markdown \ - py3-alembic@edge \ - py3-cssselect@edge \ - py3-commonmark@edge \ + py3-alembic \ +# py3-cssselect \ + py3-commonmark \ py3-pygments \ - py3-tz@edge \ - py3-tzlocal@edge \ - py3-regex@edge \ - py3-wcwidth@edge \ + py3-tz \ +# py3-tzlocal \ + py3-regex \ + py3-wcwidth \ # encryption py3-cffi \ - olm-dev \ + py3-olm \ py3-pycryptodome \ py3-unpaddedbase64 \ py3-future \ @@ -38,22 +32,22 @@ RUN apk add --no-cache \ py3-pillow \ py3-magic \ py3-feedparser \ - py3-lxml \ - py3-gitlab@edge \ - py3-semver@edge + py3-lxml +# py3-gitlab +# py3-semver # TODO remove pillow, magic, feedparser, lxml, gitlab and semver when maubot supports installing dependencies COPY requirements.txt /opt/maubot/requirements.txt COPY optional-requirements.txt /opt/maubot/optional-requirements.txt WORKDIR /opt/maubot RUN apk add --virtual .build-deps python3-dev build-base git \ - && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \ - && pip3 install -r requirements.txt -r optional-requirements.txt \ - dateparser langdetect python-gitlab pyquery cchardet \ + && pip3 install --break-system-packages -r requirements.txt -r optional-requirements.txt \ + dateparser langdetect python-gitlab pyquery semver tzlocal cssselect \ && apk del .build-deps # TODO also remove dateparser, langdetect and pyquery when maubot supports installing dependencies COPY . /opt/maubot +RUN cp /opt/maubot/maubot/example-config.yaml /opt/maubot COPY ./docker/mbc.sh /usr/local/bin/mbc ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data VOLUME /data diff --git a/Dockerfile.local b/Dockerfile.local new file mode 100644 index 0000000..d37e220 --- /dev/null +++ b/Dockerfile.local @@ -0,0 +1,29 @@ +FROM r.batts.cloud/nodejs:22 AS frontend-builder + +COPY ./maubot/management/frontend /frontend +RUN cd /frontend && yarn --prod && yarn build + +FROM r.batts.cloud/debian:bookworm + +RUN apt update && \ + apt install -y --no-install-recommends python3 python3-dev python3-venv python3-semver git gosu yq brotli && \ + apt clean -y && \ + rm -rf /var/lib/apt/lists/* + +COPY requirements.txt /opt/maubot/requirements.txt +COPY optional-requirements.txt /opt/maubot/optional-requirements.txt +WORKDIR /opt/maubot +RUN python3 -m venv /venv \ + && bash -c 'source /venv/bin/activate \ + && pip3 install -r requirements.txt -r optional-requirements.txt \ + dateparser langdetect python-gitlab pyquery tzlocal pyfiglet emoji feedparser brotli' +# TODO also remove pyfiglet, emoji, dateparser, langdetect and pyquery when maubot supports installing dependencies + +COPY . /opt/maubot +RUN cp maubot/example-config.yaml . +COPY ./docker/mbc.sh /usr/local/bin/mbc +COPY --from=frontend-builder /frontend/build /opt/maubot/frontend +ENV UID=1337 GID=1337 XDG_CONFIG_HOME=/data +VOLUME /data + +CMD ["/opt/maubot/docker/run.sh"] diff --git a/MANIFEST.in b/MANIFEST.in index daa36da..d8889bc 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,5 @@ include README.md +include CHANGELOG.md include LICENSE include requirements.txt include optional-requirements.txt diff --git a/README.md b/README.md index 75be206..02a4b6f 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,11 @@ # maubot +![Languages](https://img.shields.io/github/languages/top/maubot/maubot.svg) +[![License](https://img.shields.io/github/license/maubot/maubot.svg)](LICENSE) +[![Release](https://img.shields.io/github/release/maubot/maubot/all.svg)](https://github.com/maubot/maubot/releases) +[![GitLab CI](https://mau.dev/maubot/maubot/badges/master/pipeline.svg)](https://mau.dev/maubot/maubot/container_registry) +[![Code style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) +[![Imports](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) + A plugin-based [Matrix](https://matrix.org) bot system written in Python. ## Documentation @@ -15,41 +22,8 @@ All setup and usage instructions are located on Matrix room: [#maubot:maunium.net](https://matrix.to/#/#maubot:maunium.net) ## Plugins -* [jesaribot](https://github.com/maubot/jesaribot) - A simple bot that replies with an image when you say "jesari". -* [sed](https://github.com/maubot/sed) - A bot to do sed-like replacements. -* [factorial](https://github.com/maubot/factorial) - A bot to calculate unexpected factorials. -* [media](https://github.com/maubot/media) - A bot that replies with the MXC URI of images you send it. -* [dice](https://github.com/maubot/dice) - A combined dice rolling and calculator bot. -* [karma](https://github.com/maubot/karma) - A user karma tracker bot. -* [xkcd](https://github.com/maubot/xkcd) - A bot to view xkcd comics. -* [echo](https://github.com/maubot/echo) - A bot that echoes pings and other stuff. -* [rss](https://github.com/maubot/rss) - A bot that posts RSS feed updates to Matrix. -* [reddit](https://github.com/TomCasavant/RedditMaubot) - A bot that condescendingly corrects a user when they enter an r/subreddit without providing a link to that subreddit -* [giphy](https://github.com/TomCasavant/GiphyMaubot) - A bot that generates a gif (from giphy) given search terms -* [trump](https://github.com/jeffcasavant/MaubotTrumpTweet) - A bot that generates a Trump tweet with the given content -* [poll](https://github.com/TomCasavant/PollMaubot) - A bot that will create a simple poll for users in a room -* [urban](https://github.com/dvdgsng/UrbanMaubot) - A bot that fetches definitions from [Urban Dictionary](https://www.urbandictionary.com/). -* [reminder](https://github.com/maubot/reminder) - A bot to remind you about things. -* [translate](https://github.com/maubot/translate) - A bot to translate words. -* [reactbot](https://github.com/maubot/reactbot) - A bot that responds to messages that match predefined rules. -* [exec](https://github.com/maubot/exec) - A bot that executes code. -* [commitstrip](https://github.com/maubot/commitstrip) - A bot to view CommitStrips. -* [supportportal](https://github.com/maubot/supportportal) - A bot to manage customer support on Matrix. -* [gitlab](https://github.com/maubot/gitlab) - A GitLab client and webhook receiver. -* [github](https://github.com/maubot/github) - A GitHub client and webhook receiver. -* [gitea](https://github.com/saces/maugitea) - A Gitea client and webhook receiver. -* [twilio](https://github.com/jeffcasavant/MaubotTwilio) - Maubot-based SMS bridge -* [tmdb](https://codeberg.org/lomion/tmdb-bot) - A bot that posts information about movies fetched from TheMovieDB.org. -* [tex](https://github.com/maubot/tex) - A bot that renders LaTeX. -* [altalias](https://github.com/maubot/altalias) - A bot that lets users publish alternate aliases in rooms. -* [satwcomic](https://github.com/maubot/satwcomic) - A bot to view SatWComics. -* [songwhip](https://github.com/maubot/songwhip) - A bot to post Songwhip links. -* [invite](https://github.com/williamkray/maubot-invite) - A bot to generate invitation tokens from [matrix-registration](https://github.com/ZerataX/matrix-registration) -* [wolframalpha](https://github.com/ggogel/WolframAlphaMaubot) - A bot that allows requesting information from [WolframAlpha](https://www.wolframalpha.com/). -* [pingcheck](https://edugit.org/nik/maubot-pingcheck) - A bot to ping the echo bot and send rtt to Icinga passive check -* [ticker](https://github.com/williamkray/maubot-ticker) - A bot to return financial data about a stock or cryptocurrency. -* [weather](https://github.com/kellya/maubot-weather) - A bot to get the weather from wttr.in and return a single line of text for the location specified +A list of plugins can be found at [plugins.mau.bot](https://plugins.mau.bot/). -Open a pull request or join the Matrix room linked above to get your plugin listed here +To add your plugin to the list, send a pull request to . -The plugin wishlist lives at https://github.com/maubot/plugin-wishlist/issues +The plugin wishlist lives at . diff --git a/alembic.ini b/alembic.ini deleted file mode 100644 index 0d78e89..0000000 --- a/alembic.ini +++ /dev/null @@ -1,83 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = alembic - -# template used to generate migration files -# file_template = %%(rev)s_%%(slug)s - -# timezone to use when rendering the date -# within the migration file as well as the filename. -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; this defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path -# version_locations = %(here)s/bar %(here)s/bat alembic/versions - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks=black -# black.type=console_scripts -# black.entrypoint=black -# black.options=-l 79 - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = WARN -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README deleted file mode 100644 index 98e4f9c..0000000 --- a/alembic/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/alembic/env.py b/alembic/env.py deleted file mode 100644 index 9946810..0000000 --- a/alembic/env.py +++ /dev/null @@ -1,92 +0,0 @@ -from logging.config import fileConfig - -from sqlalchemy import engine_from_config, pool - -from alembic import context - -import sys -from os.path import abspath, dirname - -sys.path.insert(0, dirname(dirname(abspath(__file__)))) - -from mautrix.util.db import Base -from maubot.config import Config -from maubot import db - -# this is the Alembic Config object, which provides -# access to the values within the .ini file in use. -config = context.config - -maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml") -maubot_config = Config(maubot_config_path, None) -maubot_config.load() -config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%")) - -# Interpret the config file for Python logging. -# This line sets up loggers basically. -fileConfig(config.config_file_name) - -# add your model's MetaData object here -# for 'autogenerate' support -# from myapp import mymodel -# target_metadata = mymodel.Base.metadata -target_metadata = Base.metadata - -# other values from the config, defined by the needs of env.py, -# can be acquired: -# my_important_option = config.get_main_option("my_important_option") -# ... etc. - - -def run_migrations_offline(): - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - -def run_migrations_online(): - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine_from_config( - config.get_section(config.config_ini_section), - prefix="sqlalchemy.", - poolclass=pool.NullPool, - ) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako deleted file mode 100644 index 2c01563..0000000 --- a/alembic/script.py.mako +++ /dev/null @@ -1,24 +0,0 @@ -"""${message} - -Revision ID: ${up_revision} -Revises: ${down_revision | comma,n} -Create Date: ${create_date} - -""" -from alembic import op -import sqlalchemy as sa -${imports if imports else ""} - -# revision identifiers, used by Alembic. -revision = ${repr(up_revision)} -down_revision = ${repr(down_revision)} -branch_labels = ${repr(branch_labels)} -depends_on = ${repr(depends_on)} - - -def upgrade(): - ${upgrades if upgrades else "pass"} - - -def downgrade(): - ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/4b93300852aa_add_device_id_to_clients.py b/alembic/versions/4b93300852aa_add_device_id_to_clients.py deleted file mode 100644 index efc71cd..0000000 --- a/alembic/versions/4b93300852aa_add_device_id_to_clients.py +++ /dev/null @@ -1,32 +0,0 @@ -"""Add device_id to clients - -Revision ID: 4b93300852aa -Revises: fccd1f95544d -Create Date: 2020-07-11 15:49:38.831459 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = '4b93300852aa' -down_revision = 'fccd1f95544d' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('client', schema=None) as batch_op: - batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True)) - - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table('client', schema=None) as batch_op: - batch_op.drop_column('device_id') - - # ### end Alembic commands ### diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py deleted file mode 100644 index 37a68eb..0000000 --- a/alembic/versions/90aa88820eab_add_matrix_state_store.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Add Matrix state store - -Revision ID: 90aa88820eab -Revises: 4b93300852aa -Create Date: 2020-07-12 01:50:06.215623 - -""" -from alembic import op -import sqlalchemy as sa - -from mautrix.client.state_store.sqlalchemy import SerializableType -from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent - - -# revision identifiers, used by Alembic. -revision = '90aa88820eab' -down_revision = '4b93300852aa' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('mx_room_state', - sa.Column('room_id', sa.String(length=255), nullable=False), - sa.Column('is_encrypted', sa.Boolean(), nullable=True), - sa.Column('has_full_member_list', sa.Boolean(), nullable=True), - sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True), - sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True), - sa.PrimaryKeyConstraint('room_id') - ) - op.create_table('mx_user_profile', - sa.Column('room_id', sa.String(length=255), nullable=False), - sa.Column('user_id', sa.String(length=255), nullable=False), - sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False), - sa.Column('displayname', sa.String(), nullable=True), - sa.Column('avatar_url', sa.String(length=255), nullable=True), - sa.PrimaryKeyConstraint('room_id', 'user_id') - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('mx_user_profile') - op.drop_table('mx_room_state') - # ### end Alembic commands ### diff --git a/alembic/versions/d295f8dcfa64_initial_revision.py b/alembic/versions/d295f8dcfa64_initial_revision.py deleted file mode 100644 index ffa502f..0000000 --- a/alembic/versions/d295f8dcfa64_initial_revision.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Initial revision - -Revision ID: d295f8dcfa64 -Revises: -Create Date: 2019-09-27 00:21:02.527915 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'd295f8dcfa64' -down_revision = None -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.create_table('client', - sa.Column('id', sa.String(length=255), nullable=False), - sa.Column('homeserver', sa.String(length=255), nullable=False), - sa.Column('access_token', sa.Text(), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False), - sa.Column('next_batch', sa.String(length=255), nullable=False), - sa.Column('filter_id', sa.String(length=255), nullable=False), - sa.Column('sync', sa.Boolean(), nullable=False), - sa.Column('autojoin', sa.Boolean(), nullable=False), - sa.Column('displayname', sa.String(length=255), nullable=False), - sa.Column('avatar_url', sa.String(length=255), nullable=False), - sa.PrimaryKeyConstraint('id') - ) - op.create_table('plugin', - sa.Column('id', sa.String(length=255), nullable=False), - sa.Column('type', sa.String(length=255), nullable=False), - sa.Column('enabled', sa.Boolean(), nullable=False), - sa.Column('primary_user', sa.String(length=255), nullable=False), - sa.Column('config', sa.Text(), nullable=False), - sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'), - sa.PrimaryKeyConstraint('id') - ) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('plugin') - op.drop_table('client') - # ### end Alembic commands ### diff --git a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py b/alembic/versions/fccd1f95544d_add_online_field_to_clients.py deleted file mode 100644 index 1f7eabe..0000000 --- a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Add online field to clients - -Revision ID: fccd1f95544d -Revises: d295f8dcfa64 -Create Date: 2020-03-06 15:07:50.136644 - -""" -from alembic import op -import sqlalchemy as sa - - -# revision identifiers, used by Alembic. -revision = 'fccd1f95544d' -down_revision = 'd295f8dcfa64' -branch_labels = None -depends_on = None - - -def upgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("client") as batch_op: - batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true())) - # ### end Alembic commands ### - - -def downgrade(): - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("client") as batch_op: - batch_op.drop_column('online') - # ### end Alembic commands ### diff --git a/dev-requirements.txt b/dev-requirements.txt new file mode 100644 index 0000000..bb8c2a0 --- /dev/null +++ b/dev-requirements.txt @@ -0,0 +1,3 @@ +pre-commit>=2.10.1,<3 +isort>=5.10.1,<6 +black>=24,<25 diff --git a/docker/example-config.yaml b/docker/example-config.yaml deleted file mode 100644 index 192a420..0000000 --- a/docker/example-config.yaml +++ /dev/null @@ -1,108 +0,0 @@ -# The full URI to the database. SQLite and Postgres are fully supported. -# Other DBMSes supported by SQLAlchemy may or may not work. -# Format examples: -# SQLite: sqlite:///filename.db -# Postgres: postgres://username:password@hostname/dbname -database: sqlite:////data/maubot.db - -# Database for encryption data. -crypto_database: - # Type of database. Either "default", "pickle" or "postgres". - # When set to default, using SQLite as the main database will use pickle as the crypto database - # and using Postgres as the main database will use the same one as the crypto database. - # - # When using pickle, individual crypto databases are stored in the pickle_dir directory. - # When using non-default postgres, postgres_uri is used to connect to postgres. - type: default - postgres_uri: postgres://username:password@hostname/dbname - pickle_dir: /data/crypto - -plugin_directories: - # The directory where uploaded new plugins should be stored. - upload: /data/plugins - # The directories from which plugins should be loaded. - # Duplicate plugin IDs will be moved to the trash. - load: - - /data/plugins - # The directory where old plugin versions and conflicting plugins should be moved. - # Set to "delete" to delete files immediately. - trash: /data/trash - # The directory where plugin databases should be stored. - db: /data/plugins - -server: - # The IP and port to listen to. - hostname: 0.0.0.0 - port: 29316 - # Public base URL where the server is visible. - public_url: https://example.com - # The base management API path. - base_path: /_matrix/maubot/v1 - # The base path for the UI. - ui_base_path: /_matrix/maubot - # The base path for plugin endpoints. The instance ID will be appended directly. - plugin_base_path: /_matrix/maubot/plugin/ - # Override path from where to load UI resources. - # Set to false to using pkg_resources to find the path. - override_resource_path: /opt/maubot/frontend - # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. - appservice_base_path: /_matrix/app/v1 - # The shared secret to sign API access tokens. - # Set to "generate" to generate and save a new token at startup. - unshared_secret: generate - -# Shared registration secrets to allow registering new users from the management UI -registration_secrets: - example.com: - # Client-server API URL - url: https://example.com - # registration_shared_secret from synapse config - secret: synapse_shared_registration_secret - -# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password -# to prevent normal login. Root is a special user that can't have a password and will always exist. -admins: - root: "" - -# API feature switches. -api_features: - login: true - plugin: true - plugin_upload: true - instance: true - instance_database: true - client: true - client_proxy: true - client_auth: true - dev_open: true - log: true - -# Python logging configuration. -# -# See section 16.7.2 of the Python documentation for more info: -# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema -logging: - version: 1 - formatters: - precise: - format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" - handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: precise - filename: /var/log/maubot.log - maxBytes: 10485760 - backupCount: 10 - console: - class: logging.StreamHandler - formatter: precise - loggers: - maubot: - level: DEBUG - mautrix: - level: DEBUG - aiohttp: - level: INFO - root: - level: DEBUG - handlers: [file, console] diff --git a/docker/mbc.sh b/docker/mbc.sh index bffbd5e..5bde65a 100755 --- a/docker/mbc.sh +++ b/docker/mbc.sh @@ -1,3 +1,3 @@ #!/bin/sh -cd /opt/maubot +export PYTHONPATH=/opt/maubot python3 -m maubot.cli "$@" diff --git a/docker/run.sh b/docker/run.sh index 96a60b9..1ec95a2 100755 --- a/docker/run.sh +++ b/docker/run.sh @@ -1,21 +1,50 @@ -#!/bin/sh +#!/bin/bash function fixperms { - chown -R $UID:$GID /var/log /data /opt/maubot + chown -R $UID:$GID /var/log /data +} + +function fixdefault { + _value=$(yq e "$1" /data/config.yaml) + if [[ "$_value" == "$2" ]]; then + yq e -i "$1 = "'"'"$3"'"' /data/config.yaml + fi +} + +function fixconfig { + # Change relative default paths to absolute paths in /data + fixdefault '.database' 'sqlite:maubot.db' 'sqlite:/data/maubot.db' + fixdefault '.plugin_directories.upload' './plugins' '/data/plugins' + fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins' + fixdefault '.plugin_directories.trash' './trash' '/data/trash' + fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs' + fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs' + fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log' + # This doesn't need to be configurable + yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml } cd /opt/maubot -mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs /data/crypto +mkdir -p /var/log/maubot /data/plugins /data/trash /data/dbs if [ ! -f /data/config.yaml ]; then - cp docker/example-config.yaml /data/config.yaml + cp example-config.yaml /data/config.yaml echo "Config file not found. Example config copied to /data/config.yaml" echo "Please modify the config file to your liking and restart the container." fixperms + fixconfig exit fi -alembic -x config=/data/config.yaml upgrade head fixperms -exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml -b docker/example-config.yaml +fixconfig +if ls /data/plugins/*.db > /dev/null 2>&1; then + mv -n /data/plugins/*.db /data/dbs/ +fi + +if [ -f "/venv/bin/activate" ] ; then + exec gosu $UID:$GID bash -c 'source /venv/bin/activate && python3 -m maubot -c /data/config.yaml' +fi +exec su-exec $UID:$GID python3 -m maubot -c /data/config.yaml + diff --git a/example-config.yaml b/example-config.yaml deleted file mode 100644 index 89d3965..0000000 --- a/example-config.yaml +++ /dev/null @@ -1,113 +0,0 @@ -# The full URI to the database. SQLite and Postgres are fully supported. -# Other DBMSes supported by SQLAlchemy may or may not work. -# Format examples: -# SQLite: sqlite:///filename.db -# Postgres: postgres://username:password@hostname/dbname -database: sqlite:///maubot.db - -# Database for encryption data. -crypto_database: - # Type of database. Either "default", "pickle" or "postgres". - # When set to default, using SQLite as the main database will use pickle as the crypto database - # and using Postgres as the main database will use the same one as the crypto database. - # - # When using pickle, individual crypto databases are stored in the pickle_dir directory. - # When using non-default postgres, postgres_uri is used to connect to postgres. - # - # WARNING: The pickle database is dangerous and should not be used in production. - type: default - postgres_uri: postgres://username:password@hostname/dbname - pickle_dir: ./crypto - -plugin_directories: - # The directory where uploaded new plugins should be stored. - upload: ./plugins - # The directories from which plugins should be loaded. - # Duplicate plugin IDs will be moved to the trash. - load: - - ./plugins - # The directory where old plugin versions and conflicting plugins should be moved. - # Set to "delete" to delete files immediately. - trash: ./trash - # The directory where plugin databases should be stored. - db: ./plugins - -server: - # The IP and port to listen to. - hostname: 0.0.0.0 - port: 29316 - # Public base URL where the server is visible. - public_url: https://example.com - # The base management API path. - base_path: /_matrix/maubot/v1 - # The base path for the UI. - ui_base_path: /_matrix/maubot - # The base path for plugin endpoints. The instance ID will be appended directly. - plugin_base_path: /_matrix/maubot/plugin/ - # Override path from where to load UI resources. - # Set to false to using pkg_resources to find the path. - override_resource_path: false - # The base appservice API path. Use / for legacy appservice API and /_matrix/app/v1 for v1. - appservice_base_path: /_matrix/app/v1 - # The shared secret to sign API access tokens. - # Set to "generate" to generate and save a new token at startup. - unshared_secret: generate - -# Shared registration secrets to allow registering new users from the management UI -registration_secrets: - example.com: - # Client-server API URL - url: https://example.com - # registration_shared_secret from synapse config - secret: synapse_shared_registration_secret - -# List of administrator users. Plaintext passwords will be bcrypted on startup. Set empty password -# to prevent normal login. Root is a special user that can't have a password and will always exist. -admins: - root: "" - -# API feature switches. -api_features: - login: true - plugin: true - plugin_upload: true - instance: true - instance_database: true - client: true - client_proxy: true - client_auth: true - dev_open: true - log: true - -# Python logging configuration. -# -# See section 16.7.2 of the Python documentation for more info: -# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema -logging: - version: 1 - formatters: - colored: - (): maubot.lib.color_log.ColorFormatter - format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" - normal: - format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" - handlers: - file: - class: logging.handlers.RotatingFileHandler - formatter: normal - filename: ./maubot.log - maxBytes: 10485760 - backupCount: 10 - console: - class: logging.StreamHandler - formatter: colored - loggers: - maubot: - level: DEBUG - mautrix: - level: DEBUG - aiohttp: - level: INFO - root: - level: DEBUG - handlers: [file, console] diff --git a/examples/LICENSE b/examples/LICENSE index bfdfe68..a4b60f3 100644 --- a/examples/LICENSE +++ b/examples/LICENSE @@ -1,6 +1,6 @@ The MIT License (MIT) -Copyright (c) 2018 Tulir Asokan +Copyright (c) 2022 Tulir Asokan Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/examples/README.md b/examples/README.md index 1837fec..2efabca 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,3 +4,4 @@ All examples are published under the [MIT license](LICENSE). * [Hello World](helloworld/) - Very basic event handling bot that responds "Hello, World!" to all messages. * [Echo bot](https://github.com/maubot/echo) - Basic command handling bot with !echo and !ping commands * [Config example](config/) - Simple example of using a config file +* [Database example](database/) - Simple example of using a database diff --git a/examples/config/base-config.yaml b/examples/config/base-config.yaml index c621847..0f7f8a3 100644 --- a/examples/config/base-config.yaml +++ b/examples/config/base-config.yaml @@ -1,2 +1,5 @@ -# Message to send when user sends !getmessage -message: Default configuration active +# Who is allowed to use the bot? +whitelist: + - "@user:example.com" +# The prefix for the main command without the ! +command_prefix: hello-world diff --git a/examples/config/configurablebot.py b/examples/config/configurablebot.py index 13624be..54b47b6 100644 --- a/examples/config/configurablebot.py +++ b/examples/config/configurablebot.py @@ -1,5 +1,4 @@ from typing import Type - from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper from maubot import Plugin, MessageEvent from maubot.handlers import command @@ -7,19 +6,22 @@ from maubot.handlers import command class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: - helper.copy("message") + helper.copy("whitelist") + helper.copy("command_prefix") -class DatabaseBot(Plugin): +class ConfigurableBot(Plugin): async def start(self) -> None: - await super().start() self.config.load_and_update() + def get_command_name(self) -> str: + return self.config["command_prefix"] + + @command.new(name=get_command_name) + async def hmm(self, evt: MessageEvent) -> None: + if evt.sender in self.config["whitelist"]: + await evt.reply("You're whitelisted 🎉") + @classmethod def get_config_class(cls) -> Type[BaseProxyConfig]: return Config - - @command.new("getmessage") - async def handler(self, event: MessageEvent) -> None: - if event.sender != self.client.mxid: - await event.reply(self.config["message"]) diff --git a/examples/config/maubot.yaml b/examples/config/maubot.yaml index b049dba..8ab36a9 100644 --- a/examples/config/maubot.yaml +++ b/examples/config/maubot.yaml @@ -1,11 +1,12 @@ maubot: 0.1.0 -id: xyz.maubot.databasebot -version: 1.0.0 +id: xyz.maubot.configurablebot +version: 2.0.0 license: MIT modules: - configurablebot main_class: ConfigurableBot database: false +config: true # Instruct the build tool to include the base config. extra_files: diff --git a/examples/database/maubot.yaml b/examples/database/maubot.yaml new file mode 100644 index 0000000..84f4f69 --- /dev/null +++ b/examples/database/maubot.yaml @@ -0,0 +1,10 @@ +maubot: 0.1.0 +id: xyz.maubot.storagebot +version: 2.0.0 +license: MIT +modules: +- storagebot +main_class: StorageBot +database: true +database_type: asyncpg +config: false diff --git a/examples/database/storagebot.py b/examples/database/storagebot.py new file mode 100644 index 0000000..786bba5 --- /dev/null +++ b/examples/database/storagebot.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from mautrix.util.async_db import UpgradeTable, Connection +from maubot import Plugin, MessageEvent +from maubot.handlers import command + +upgrade_table = UpgradeTable() + + +@upgrade_table.register(description="Initial revision") +async def upgrade_v1(conn: Connection) -> None: + await conn.execute( + """CREATE TABLE stored_data ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + )""" + ) + + +@upgrade_table.register(description="Remember user who added value") +async def upgrade_v2(conn: Connection) -> None: + await conn.execute("ALTER TABLE stored_data ADD COLUMN creator TEXT") + + +class StorageBot(Plugin): + @command.new() + async def storage(self, evt: MessageEvent) -> None: + pass + + @storage.subcommand(help="Store a value") + @command.argument("key") + @command.argument("value", pass_raw=True) + async def put(self, evt: MessageEvent, key: str, value: str) -> None: + q = """ + INSERT INTO stored_data (key, value, creator) VALUES ($1, $2, $3) + ON CONFLICT (key) DO UPDATE SET value=excluded.value, creator=excluded.creator + """ + await self.database.execute(q, key, value, evt.sender) + await evt.reply(f"Inserted {key} into the database") + + @storage.subcommand(help="Get a value from the storage") + @command.argument("key") + async def get(self, evt: MessageEvent, key: str) -> None: + q = "SELECT key, value, creator FROM stored_data WHERE LOWER(key)=LOWER($1)" + row = await self.database.fetchrow(q, key) + if row: + key = row["key"] + value = row["value"] + creator = row["creator"] + await evt.reply(f"`{key}` stored by {creator}:\n\n```\n{value}\n```") + else: + await evt.reply(f"No data stored under `{key}` :(") + + @storage.subcommand(help="List keys in the storage") + @command.argument("prefix", required=False) + async def list(self, evt: MessageEvent, prefix: str | None) -> None: + q = "SELECT key, creator FROM stored_data WHERE key LIKE $1" + rows = await self.database.fetch(q, prefix + "%") + prefix_reply = f" starting with `{prefix}`" if prefix else "" + if len(rows) == 0: + await evt.reply(f"Nothing{prefix_reply} stored in database :(") + else: + formatted_data = "\n".join( + f"* `{row['key']}` stored by {row['creator']}" for row in rows + ) + await evt.reply( + f"Found {len(rows)} keys{prefix_reply} in database:\n\n{formatted_data}" + ) + + @classmethod + def get_db_upgrade_table(cls) -> UpgradeTable | None: + return upgrade_table diff --git a/maubot/__init__.py b/maubot/__init__.py index 9ca0322..5106b46 100644 --- a/maubot/__init__.py +++ b/maubot/__init__.py @@ -1,3 +1,4 @@ +from .__meta__ import __version__ +from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent from .plugin_base import Plugin from .plugin_server import PluginWebApp -from .matrix import MaubotMatrixClient as Client, MaubotMessageEvent as MessageEvent diff --git a/maubot/__main__.py b/maubot/__main__.py index 2ef73f9..c4cba44 100644 --- a/maubot/__main__.py +++ b/maubot/__main__.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,84 +13,171 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import logging.config -import argparse +from __future__ import annotations + import asyncio -import signal -import copy import sys -from .config import Config -from .db import init as init_db -from .server import MaubotServer -from .client import Client, init as init_client_class -from .loader.zip import init as init_zip_loader -from .instance import init as init_plugin_instance_class -from .management.api import init as init_mgmt_api +from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme +from mautrix.util.program import Program + from .__meta__ import __version__ - -parser = argparse.ArgumentParser(description="A plugin-based Matrix bot system.", - prog="python -m maubot") -parser.add_argument("-c", "--config", type=str, default="config.yaml", - metavar="", help="the path to your config file") -parser.add_argument("-b", "--base-config", type=str, default="example-config.yaml", - metavar="", help="the path to the example config " - "(for automatic config updates)") -args = parser.parse_args() - -config = Config(args.config, args.base_config) -config.load() -config.update() - -logging.config.dictConfig(copy.deepcopy(config["logging"])) - -loop = asyncio.get_event_loop() - -stop_log_listener = None -if config["api_features.log"]: - from .management.api.log import init as init_log_listener, stop_all as stop_log_listener - - init_log_listener(loop) - -log = logging.getLogger("maubot.init") -log.info(f"Initializing maubot {__version__}") - -init_zip_loader(config) -db_engine = init_db(config) -clients = init_client_class(config, loop) -management_api = init_mgmt_api(config, loop) -server = MaubotServer(management_api, config, loop) -plugins = init_plugin_instance_class(config, server, loop) - -for plugin in plugins: - plugin.load() - -signal.signal(signal.SIGINT, signal.default_int_handler) -signal.signal(signal.SIGTERM, signal.default_int_handler) - +from .client import Client +from .config import Config +from .db import init as init_db, upgrade_table +from .instance import PluginInstance +from .lib.future_awaitable import FutureAwaitable +from .lib.state_store import PgStateStore +from .loader.zip import init as init_zip_loader +from .management.api import init as init_mgmt_api +from .server import MaubotServer try: - log.info("Starting server") - loop.run_until_complete(server.start()) - if Client.crypto_db: - log.debug("Starting client crypto database") - loop.run_until_complete(Client.crypto_db.start()) - log.info("Starting clients and plugins") - loop.run_until_complete(asyncio.gather(*[client.start() for client in clients])) - log.info("Startup actions complete, running forever") - loop.run_forever() -except KeyboardInterrupt: - log.info("Interrupt received, stopping clients") - loop.run_until_complete(asyncio.gather(*[client.stop() for client in Client.cache.values()])) - if stop_log_listener is not None: - log.debug("Closing websockets") - loop.run_until_complete(stop_log_listener()) - log.debug("Stopping server") - try: - loop.run_until_complete(asyncio.wait_for(server.stop(), 5, loop=loop)) - except asyncio.TimeoutError: - log.warning("Stopping server timed out") - log.debug("Closing event loop") - loop.close() - log.debug("Everything stopped, shutting down") - sys.exit(0) + from mautrix.crypto.store import PgCryptoStore +except ImportError: + PgCryptoStore = None + + +class Maubot(Program): + config: Config + server: MaubotServer + db: Database + crypto_db: Database | None + plugin_postgres_db: PostgresDatabase | None + state_store: PgStateStore + + config_class = Config + module = "maubot" + name = "maubot" + version = __version__ + command = "python -m maubot" + description = "A plugin-based Matrix bot system." + + def prepare_log_websocket(self) -> None: + from .management.api.log import init, stop_all + + init(self.loop) + self.add_shutdown_actions(FutureAwaitable(stop_all)) + + def prepare_arg_parser(self) -> None: + super().prepare_arg_parser() + self.parser.add_argument( + "--ignore-unsupported-database", + action="store_true", + help="Run even if the database schema is too new", + ) + self.parser.add_argument( + "--ignore-foreign-tables", + action="store_true", + help="Run even if the database contains tables from other programs (like Synapse)", + ) + + def prepare_db(self) -> None: + self.db = Database.create( + self.config["database"], + upgrade_table=upgrade_table, + db_args=self.config["database_opts"], + owner_name=self.name, + ignore_foreign_tables=self.args.ignore_foreign_tables, + ) + init_db(self.db) + + if PgCryptoStore: + if self.config["crypto_database"] == "default": + self.crypto_db = self.db + else: + self.crypto_db = Database.create( + self.config["crypto_database"], + upgrade_table=PgCryptoStore.upgrade_table, + ignore_foreign_tables=self.args.ignore_foreign_tables, + ) + else: + self.crypto_db = None + + if self.config["plugin_databases.postgres"] == "default": + if self.db.scheme != Scheme.POSTGRES: + self.log.critical( + 'Using "default" as the postgres plugin database URL is only allowed if ' + "the default database is postgres." + ) + sys.exit(24) + assert isinstance(self.db, PostgresDatabase) + self.plugin_postgres_db = self.db + elif self.config["plugin_databases.postgres"]: + plugin_db = Database.create( + self.config["plugin_databases.postgres"], + db_args={ + **self.config["database_opts"], + **self.config["plugin_databases.postgres_opts"], + }, + ) + if plugin_db.scheme != Scheme.POSTGRES: + self.log.critical("The plugin postgres database URL must be a postgres database") + sys.exit(24) + assert isinstance(plugin_db, PostgresDatabase) + self.plugin_postgres_db = plugin_db + else: + self.plugin_postgres_db = None + + def prepare(self) -> None: + super().prepare() + + if self.config["api_features.log"]: + self.prepare_log_websocket() + + init_zip_loader(self.config) + self.prepare_db() + Client.init_cls(self) + PluginInstance.init_cls(self) + management_api = init_mgmt_api(self.config, self.loop) + self.server = MaubotServer(management_api, self.config, self.loop) + self.state_store = PgStateStore(self.db) + + async def start_db(self) -> None: + self.log.debug("Starting database...") + ignore_unsupported = self.args.ignore_unsupported_database + self.db.upgrade_table.allow_unsupported = ignore_unsupported + self.state_store.upgrade_table.allow_unsupported = ignore_unsupported + try: + await self.db.start() + await self.state_store.upgrade_table.upgrade(self.db) + if self.plugin_postgres_db and self.plugin_postgres_db is not self.db: + await self.plugin_postgres_db.start() + if self.crypto_db: + PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported + if self.crypto_db is not self.db: + await self.crypto_db.start() + else: + await PgCryptoStore.upgrade_table.upgrade(self.db) + except DatabaseException as e: + self.log.critical("Failed to initialize database", exc_info=e) + if e.explanation: + self.log.info(e.explanation) + sys.exit(25) + + async def system_exit(self) -> None: + if hasattr(self, "db"): + self.log.trace("Stopping database due to SystemExit") + await self.db.stop() + + async def start(self) -> None: + await self.start_db() + await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()]) + await asyncio.gather(*[client.start() async for client in Client.all()]) + await super().start() + async for plugin in PluginInstance.all(): + await plugin.load() + await self.server.start() + + async def stop(self) -> None: + self.add_shutdown_actions(*(client.stop() for client in Client.cache.values())) + await super().stop() + self.log.debug("Stopping server") + try: + await asyncio.wait_for(self.server.stop(), 5) + except asyncio.TimeoutError: + self.log.warning("Stopping server timed out") + await self.db.stop() + + +Maubot().run() diff --git a/maubot/__meta__.py b/maubot/__meta__.py index 485f44a..7225152 100644 --- a/maubot/__meta__.py +++ b/maubot/__meta__.py @@ -1 +1 @@ -__version__ = "0.1.1" +__version__ = "0.5.2" diff --git a/maubot/cli/__main__.py b/maubot/cli/__main__.py index 1ffd665..3bdbe0e 100644 --- a/maubot/cli/__main__.py +++ b/maubot/cli/__main__.py @@ -1,2 +1,3 @@ from . import app + app() diff --git a/maubot/cli/base.py b/maubot/cli/base.py index 1aaeec8..b35db53 100644 --- a/maubot/cli/base.py +++ b/maubot/cli/base.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by diff --git a/maubot/cli/cliq/__init__.py b/maubot/cli/cliq/__init__.py index cba14a4..10ede9f 100644 --- a/maubot/cli/cliq/__init__.py +++ b/maubot/cli/cliq/__init__.py @@ -1,2 +1,2 @@ from .cliq import command, option -from .validators import SPDXValidator, VersionValidator, PathValidator +from .validators import PathValidator, SPDXValidator, VersionValidator diff --git a/maubot/cli/cliq/cliq.py b/maubot/cli/cliq/cliq.py index 973587a..2883441 100644 --- a/maubot/cli/cliq/cliq.py +++ b/maubot/cli/cliq/cliq.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,20 +13,55 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Any, Callable, Union, Optional -import functools +from __future__ import annotations +from typing import Any, Callable +import asyncio +import functools +import inspect +import traceback + +from colorama import Fore from prompt_toolkit.validation import Validator -from PyInquirer import prompt +from questionary import prompt +import aiohttp import click from ..base import app -from .validators import Required, ClickValidator +from ..config import get_token +from .validators import ClickValidator, Required + + +def with_http(func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + async with aiohttp.ClientSession() as sess: + try: + return await func(*args, sess=sess, **kwargs) + except aiohttp.ClientError as e: + print(f"{Fore.RED}Connection error: {e}{Fore.RESET}") + + return wrapper + + +def with_authenticated_http(func): + @functools.wraps(func) + async def wrapper(*args, server: str, **kwargs): + server, token = get_token(server) + if not token: + return + async with aiohttp.ClientSession(headers={"Authorization": f"Bearer {token}"}) as sess: + try: + return await func(*args, sess=sess, server=server, **kwargs) + except aiohttp.ClientError as e: + print(f"{Fore.RED}Connection error: {e}{Fore.RESET}") + + return wrapper def command(help: str) -> Callable[[Callable], Callable]: def decorator(func) -> Callable: - questions = func.__inquirer_questions__.copy() + questions = getattr(func, "__inquirer_questions__", {}).copy() @functools.wraps(func) def wrapper(*args, **kwargs): @@ -39,6 +74,11 @@ def command(help: str) -> Callable[[Callable], Callable]: required_unless = questions[key].pop("required_unless") if isinstance(required_unless, str) and kwargs[required_unless]: questions.pop(key) + elif isinstance(required_unless, list): + for v in required_unless: + if kwargs[v]: + questions.pop(key) + break elif isinstance(required_unless, dict): for k, v in required_unless.items(): if kwargs.get(v, object()) == v: @@ -48,18 +88,25 @@ def command(help: str) -> Callable[[Callable], Callable]: pass question_list = list(questions.values()) question_list.reverse() - resp = prompt(question_list, keyboard_interrupt_msg="Aborted!") + resp = prompt(question_list, kbi_msg="Aborted!") if not resp and question_list: return kwargs = {**kwargs, **resp} - func(*args, **kwargs) + + try: + res = func(*args, **kwargs) + if inspect.isawaitable(res): + asyncio.run(res) + except Exception: + print(Fore.RED + "Fatal error running command" + Fore.RESET) + traceback.print_exc() return app.command(help=help)(wrapper) return decorator -def yesno(val: str) -> Optional[bool]: +def yesno(val: str) -> bool | None: if not val: return None elif isinstance(val, bool): @@ -73,14 +120,25 @@ def yesno(val: str) -> Optional[bool]: yesno.__name__ = "yes/no" -def option(short: str, long: str, message: str = None, help: str = None, - click_type: Union[str, Callable[[str], Any]] = None, inq_type: str = None, - validator: Validator = None, required: bool = False, default: str = None, - is_flag: bool = False, prompt: bool = True, required_unless: str = None - ) -> Callable[[Callable], Callable]: +def option( + short: str, + long: str, + message: str = None, + help: str = None, + click_type: str | Callable[[str], Any] = None, + inq_type: str = None, + validator: type[Validator] = None, + required: bool = False, + default: str | bool | None = None, + is_flag: bool = False, + prompt: bool = True, + required_unless: str | list | dict = None, +) -> Callable[[Callable], Callable]: if not message: message = long[2].upper() + long[3:] - click_type = validator.click_type if isinstance(validator, ClickValidator) else click_type + + if isinstance(validator, type) and issubclass(validator, ClickValidator): + click_type = validator.click_type if is_flag: click_type = yesno @@ -91,9 +149,9 @@ def option(short: str, long: str, message: str = None, help: str = None, if not hasattr(func, "__inquirer_questions__"): func.__inquirer_questions__ = {} q = { - "type": (inq_type if isinstance(inq_type, str) - else ("input" if not is_flag - else "confirm")), + "type": ( + inq_type if isinstance(inq_type, str) else ("input" if not is_flag else "confirm") + ), "name": long[2:], "message": message, } @@ -102,9 +160,9 @@ def option(short: str, long: str, message: str = None, help: str = None, if default is not None: q["default"] = default if required or required_unless is not None: - q["validator"] = Required(validator) + q["validate"] = Required(validator) elif validator: - q["validator"] = validator + q["validate"] = validator func.__inquirer_questions__[long[2:]] = q return func diff --git a/maubot/cli/cliq/validators.py b/maubot/cli/cliq/validators.py index 9a57914..46d3c92 100644 --- a/maubot/cli/cliq/validators.py +++ b/maubot/cli/cliq/validators.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,9 +16,9 @@ from typing import Callable import os -from packaging.version import Version, InvalidVersion -from prompt_toolkit.validation import Validator, ValidationError +from packaging.version import InvalidVersion, Version from prompt_toolkit.document import Document +from prompt_toolkit.validation import ValidationError, Validator import click from ..util import spdx as spdxlib diff --git a/maubot/cli/commands/__init__.py b/maubot/cli/commands/__init__.py index c535234..145646b 100644 --- a/maubot/cli/commands/__init__.py +++ b/maubot/cli/commands/__init__.py @@ -1 +1 @@ -from . import upload, build, login, init, logs, auth +from . import auth, build, init, login, logs, upload diff --git a/maubot/cli/commands/auth.py b/maubot/cli/commands/auth.py index d8cb3cb..64b1dc7 100644 --- a/maubot/cli/commands/auth.py +++ b/maubot/cli/commands/auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,75 +13,154 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urllib.parse import quote -from urllib.request import urlopen, Request -from urllib.error import HTTPError -import functools import json +import webbrowser from colorama import Fore +from yarl import URL +import aiohttp import click -from ..config import get_token from ..cliq import cliq history_count: int = 10 -enc = functools.partial(quote, safe="") - friendly_errors = { - "server_not_found": "Registration target server not found.\n\n" - "To log in or register through maubot, you must add the server to the\n" - "registration_secrets section in the config. If you only want to log in,\n" - "leave the `secret` field empty." + "server_not_found": ( + "Registration target server not found.\n\n" + "To log in or register through maubot, you must add the server to the\n" + "homeservers section in the config. If you only want to log in,\n" + "leave the `secret` field empty." + ), + "registration_no_sso": ( + "The register operation is only for registering with a password.\n\n" + "To register with SSO, simply leave out the --register flag." + ), } +async def list_servers(server: str, sess: aiohttp.ClientSession) -> None: + url = URL(server) / "_matrix/maubot/v1/client/auth/servers" + async with sess.get(url) as resp: + data = await resp.json() + print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") + for server in data.keys(): + print(f"* {Fore.CYAN}{server}{Fore.RESET}") + + @cliq.command(help="Log into a Matrix account via the Maubot server") @cliq.option("-h", "--homeserver", help="The homeserver to log into", required_unless="list") -@cliq.option("-u", "--username", help="The username to log in with", required_unless="list") -@cliq.option("-p", "--password", help="The password to log in with", inq_type="password", - required_unless="list") -@cliq.option("-s", "--server", help="The maubot instance to log in through", default="", - required=False, prompt=False) -@click.option("-r", "--register", help="Register instead of logging in", is_flag=True, - default=False) +@cliq.option( + "-u", "--username", help="The username to log in with", required_unless=["list", "sso"] +) +@cliq.option( + "-p", + "--password", + help="The password to log in with", + inq_type="password", + required_unless=["list", "sso"], +) +@cliq.option( + "-s", + "--server", + help="The maubot instance to log in through", + default="", + required=False, + prompt=False, +) +@click.option( + "-r", "--register", help="Register instead of logging in", is_flag=True, default=False +) +@click.option( + "-c", + "--update-client", + help="Instead of returning the access token, create or update a client in maubot using it", + is_flag=True, + default=False, +) @click.option("-l", "--list", help="List available homeservers", is_flag=True, default=False) -def auth(homeserver: str, username: str, password: str, server: str, register: bool, list: bool - ) -> None: - server, token = get_token(server) - if not token: - return - headers = {"Authorization": f"Bearer {token}"} +@click.option( + "-o", "--sso", help="Use single sign-on instead of password login", is_flag=True, default=False +) +@click.option( + "-n", + "--device-name", + help="The initial e2ee device displayname (only for login)", + default="Maubot", + required=False, +) +@cliq.with_authenticated_http +async def auth( + homeserver: str, + username: str, + password: str, + server: str, + register: bool, + list: bool, + update_client: bool, + device_name: str, + sso: bool, + sess: aiohttp.ClientSession, +) -> None: if list: - url = f"{server}/_matrix/maubot/v1/client/auth/servers" - with urlopen(Request(url, headers=headers)) as resp_data: - resp = json.load(resp_data) - print(f"{Fore.GREEN}Available Matrix servers for registration and login:{Fore.RESET}") - for server in resp.keys(): - print(f"* {Fore.CYAN}{server}{Fore.RESET}") - return + await list_servers(server, sess) + return endpoint = "register" if register else "login" - headers["Content-Type"] = "application/json" - url = f"{server}/_matrix/maubot/v1/client/auth/{enc(homeserver)}/{endpoint}" - req = Request(url, headers=headers, - data=json.dumps({ - "username": username, - "password": password, - }).encode("utf-8")) + url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / endpoint + if update_client: + url = url.update_query({"update_client": "true"}) + if sso: + url = url.update_query({"sso": "true"}) + req_data = {"device_name": device_name} + else: + req_data = {"username": username, "password": password, "device_name": device_name} + + async with sess.post(url, json=req_data) as resp: + if not 200 <= resp.status < 300: + await print_error(resp, is_register=register) + elif sso: + await wait_sso(resp, sess, server, homeserver) + else: + await print_response(resp, is_register=register) + + +async def wait_sso( + resp: aiohttp.ClientResponse, sess: aiohttp.ClientSession, server: str, homeserver: str +) -> None: + data = await resp.json() + sso_url, reg_id = data["sso_url"], data["id"] + print(f"{Fore.GREEN}Opening {Fore.CYAN}{sso_url}{Fore.RESET}") + webbrowser.open(sso_url, autoraise=True) + print(f"{Fore.GREEN}Waiting for login token...{Fore.RESET}") + wait_url = URL(server) / "_matrix/maubot/v1/client/auth" / homeserver / "sso" / reg_id / "wait" + async with sess.post(wait_url, json={}) as resp: + await print_response(resp, is_register=False) + + +async def print_response(resp: aiohttp.ClientResponse, is_register: bool) -> None: + if resp.status == 200: + data = await resp.json() + action = "registered" if is_register else "logged in as" + print(f"{Fore.GREEN}Successfully {action} {Fore.CYAN}{data['user_id']}{Fore.GREEN}.") + print(f"{Fore.GREEN}Access token: {Fore.CYAN}{data['access_token']}{Fore.RESET}") + print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{data['device_id']}{Fore.RESET}") + elif resp.status in (201, 202): + data = await resp.json() + action = "created" if resp.status == 201 else "updated" + print( + f"{Fore.GREEN}Successfully {action} client for " + f"{Fore.CYAN}{data['id']}{Fore.GREEN} / " + f"{Fore.CYAN}{data['device_id']}{Fore.GREEN}.{Fore.RESET}" + ) + else: + await print_error(resp, is_register) + + +async def print_error(resp: aiohttp.ClientResponse, is_register: bool) -> None: try: - with urlopen(req) as resp_data: - resp = json.load(resp_data) - action = "registered" if register else "logged in as" - print(f"{Fore.GREEN}Successfully {action} " - f"{Fore.CYAN}{resp['user_id']}{Fore.GREEN}.") - print(f"{Fore.GREEN}Access token: {Fore.CYAN}{resp['access_token']}{Fore.RESET}") - print(f"{Fore.GREEN}Device ID: {Fore.CYAN}{resp['device_id']}{Fore.RESET}") - except HTTPError as e: - try: - err_data = json.load(e) - error = friendly_errors.get(err_data["errcode"], err_data["error"]) - except (json.JSONDecodeError, KeyError): - error = str(e) - action = "register" if register else "log in" - print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") + err_data = await resp.json() + error = friendly_errors.get(err_data["errcode"], err_data["error"]) + except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): + error = await resp.text() + action = "register" if is_register else "log in" + print(f"{Fore.RED}Failed to {action}: {error}{Fore.RESET}") diff --git a/maubot/cli/commands/build.py b/maubot/cli/commands/build.py index 9e3724c..39eca53 100644 --- a/maubot/cli/commands/build.py +++ b/maubot/cli/commands/build.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,21 +13,27 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Union, IO -from io import BytesIO -import zipfile -import os +from __future__ import annotations -from ruamel.yaml import YAML, YAMLError +from typing import IO +from io import BytesIO +import asyncio +import glob +import os +import zipfile + +from aiohttp import ClientSession from colorama import Fore -from PyInquirer import prompt +from questionary import prompt +from ruamel.yaml import YAML, YAMLError import click from mautrix.types import SerializerError from ...loader import PluginMeta -from ..cliq.validators import PathValidator from ..base import app +from ..cliq import cliq +from ..cliq.validators import PathValidator from ..config import get_token from .upload import upload_file @@ -40,7 +46,7 @@ def zipdir(zip, dir): zip.write(os.path.join(root, file)) -def read_meta(path: str) -> Optional[PluginMeta]: +def read_meta(path: str) -> PluginMeta | None: try: with open(os.path.join(path, "maubot.yaml")) as meta_file: try: @@ -61,7 +67,7 @@ def read_meta(path: str) -> Optional[PluginMeta]: return meta -def read_output_path(output: str, meta: PluginMeta) -> Optional[str]: +def read_output_path(output: str, meta: PluginMeta) -> str | None: directory = os.getcwd() filename = f"{meta.id}-v{meta.version}.mbp" if not output: @@ -69,18 +75,15 @@ def read_output_path(output: str, meta: PluginMeta) -> Optional[str]: elif os.path.isdir(output): output = os.path.join(output, filename) elif os.path.exists(output): - override = prompt({ - "type": "confirm", - "name": "override", - "message": f"{output} exists, override?" - })["override"] + q = [{"type": "confirm", "name": "override", "message": f"{output} exists, override?"}] + override = prompt(q)["override"] if not override: return None os.remove(output) return os.path.abspath(output) -def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None: +def write_plugin(meta: PluginMeta, output: str | IO) -> None: with zipfile.ZipFile(output, "w") as zip: meta_dump = BytesIO() yaml.dump(meta.serialize(), meta_dump) @@ -90,33 +93,47 @@ def write_plugin(meta: PluginMeta, output: Union[str, IO]) -> None: if os.path.isfile(f"{module}.py"): zip.write(f"{module}.py") elif module is not None and os.path.isdir(module): - zipdir(zip, module) + if os.path.isfile(f"{module}/__init__.py"): + zipdir(zip, module) + else: + print( + Fore.YELLOW + + f"Module {module} is missing __init__.py, skipping" + + Fore.RESET + ) else: print(Fore.YELLOW + f"Module {module} not found, skipping" + Fore.RESET) - - for file in meta.extra_files: - zip.write(file) + for pattern in meta.extra_files: + for file in glob.iglob(pattern): + zip.write(file) -def upload_plugin(output: Union[str, IO], server: str) -> None: +@cliq.with_authenticated_http +async def upload_plugin(output: str | IO, *, server: str, sess: ClientSession) -> None: server, token = get_token(server) if not token: return if isinstance(output, str): with open(output, "rb") as file: - upload_file(file, server, token) + await upload_file(sess, file, server) else: - upload_file(output, server, token) + await upload_file(sess, output, server) -@app.command(short_help="Build a maubot plugin", - help="Build a maubot plugin. First parameter is the path to root of the plugin " - "to build. You can also use --output to specify output file.") +@app.command( + short_help="Build a maubot plugin", + help=( + "Build a maubot plugin. First parameter is the path to root of the plugin " + "to build. You can also use --output to specify output file." + ), +) @click.argument("path", default=os.getcwd()) -@click.option("-o", "--output", help="Path to output built plugin to", - type=PathValidator.click_type) -@click.option("-u", "--upload", help="Upload plugin to server after building", is_flag=True, - default=False) +@click.option( + "-o", "--output", help="Path to output built plugin to", type=PathValidator.click_type +) +@click.option( + "-u", "--upload", help="Upload plugin to server after building", is_flag=True, default=False +) @click.option("-s", "--server", help="Server to upload built plugin to") def build(path: str, output: str, upload: bool, server: str) -> None: meta = read_meta(path) @@ -135,4 +152,4 @@ def build(path: str, output: str, upload: bool, server: str) -> None: else: output.seek(0) if upload: - upload_plugin(output, server) + asyncio.run(upload_plugin(output, server=server)) diff --git a/maubot/cli/commands/init.py b/maubot/cli/commands/init.py index 7372a2d..d24def9 100644 --- a/maubot/cli/commands/init.py +++ b/maubot/cli/commands/init.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,11 +13,11 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from pkg_resources import resource_string import os -from packaging.version import Version from jinja2 import Template +from packaging.version import Version +from pkg_resources import resource_string from .. import cliq from ..cliq import SPDXValidator, VersionValidator @@ -40,26 +40,55 @@ def load_templates(): @cliq.command(help="Initialize a new maubot plugin") -@cliq.option("-n", "--name", help="The name of the project", required=True, - default=os.path.basename(os.getcwd())) -@cliq.option("-i", "--id", message="ID", required=True, - help="The maubot plugin ID (Java package name format)") -@cliq.option("-v", "--version", help="Initial version for project (PEP-440 format)", - default="0.1.0", validator=VersionValidator, required=True) -@cliq.option("-l", "--license", validator=SPDXValidator, default="AGPL-3.0-or-later", - help="The license for the project (SPDX identifier)", required=False) -@cliq.option("-c", "--config", message="Should the plugin include a config?", - help="Include a config in the plugin stub", default=False, is_flag=True) +@cliq.option( + "-n", + "--name", + help="The name of the project", + required=True, + default=os.path.basename(os.getcwd()), +) +@cliq.option( + "-i", + "--id", + message="ID", + required=True, + help="The maubot plugin ID (Java package name format)", +) +@cliq.option( + "-v", + "--version", + help="Initial version for project (PEP-440 format)", + default="0.1.0", + validator=VersionValidator, + required=True, +) +@cliq.option( + "-l", + "--license", + validator=SPDXValidator, + default="AGPL-3.0-or-later", + help="The license for the project (SPDX identifier)", + required=False, +) +@cliq.option( + "-c", + "--config", + message="Should the plugin include a config?", + help="Include a config in the plugin stub", + default=False, + is_flag=True, +) def init(name: str, id: str, version: Version, license: str, config: bool) -> None: load_templates() main_class = name[0].upper() + name[1:] - meta = meta_template.render(id=id, version=str(version), license=license, config=config, - main_class=main_class) + meta = meta_template.render( + id=id, version=str(version), license=license, config=config, main_class=main_class + ) with open("maubot.yaml", "w") as file: file.write(meta) if license: with open("LICENSE", "w") as file: - file.write(spdx.get(license)["text"]) + file.write(spdx.get(license)["licenseText"]) if not os.path.isdir(name): os.mkdir(name) mod = mod_template.render(config=config, name=main_class) diff --git a/maubot/cli/commands/login.py b/maubot/cli/commands/login.py index fdf71b3..8aac0f5 100644 --- a/maubot/cli/commands/login.py +++ b/maubot/cli/commands/login.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,32 +13,55 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urllib.request import urlopen -from urllib.error import HTTPError import json import os from colorama import Fore +from yarl import URL +import aiohttp -from ..config import save_config, config from ..cliq import cliq +from ..config import config, save_config @cliq.command(help="Log in to a Maubot instance") -@cliq.option("-u", "--username", help="The username of your account", default=os.environ.get("USER", None), required=True) -@cliq.option("-p", "--password", help="The password to your account", inq_type="password", required=True) -@cliq.option("-s", "--server", help="The server to log in to", default="http://localhost:29316", required=True) -@cliq.option("-a", "--alias", help="Alias to reference the server without typing the full URL", default="", required=False) -def login(server, username, password, alias) -> None: +@cliq.option( + "-u", + "--username", + help="The username of your account", + default=os.environ.get("USER", None), + required=True, +) +@cliq.option( + "-p", "--password", help="The password to your account", inq_type="password", required=True +) +@cliq.option( + "-s", + "--server", + help="The server to log in to", + default="http://localhost:29316", + required=True, +) +@cliq.option( + "-a", + "--alias", + help="Alias to reference the server without typing the full URL", + default="", + required=False, +) +@cliq.with_http +async def login( + server: str, username: str, password: str, alias: str, sess: aiohttp.ClientSession +) -> None: data = { "username": username, "password": password, } - try: - with urlopen(f"{server}/_matrix/maubot/v1/auth/login", - data=json.dumps(data).encode("utf-8")) as resp_data: - resp = json.load(resp_data) - config["servers"][server] = resp["token"] + url = URL(server) / "_matrix/maubot/v1/auth/login" + async with sess.post(url, json=data) as resp: + if resp.status == 200: + data = await resp.json() + config["servers"][server] = data["token"] if not config["default_server"]: print(Fore.CYAN, "Setting", server, "as the default server") config["default_server"] = server @@ -46,9 +69,9 @@ def login(server, username, password, alias) -> None: config["aliases"][alias] = server save_config() print(Fore.GREEN + "Logged in successfully") - except HTTPError as e: - try: - err = json.load(e) - except json.JSONDecodeError: - err = {} - print(Fore.RED + err.get("error", str(e)) + Fore.RESET) + else: + try: + err = (await resp.json())["error"] + except (json.JSONDecodeError, KeyError): + err = await resp.text() + print(Fore.RED + err + Fore.RESET) diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py index 8d0a578..e0ed07d 100644 --- a/maubot/cli/commands/logs.py +++ b/maubot/cli/commands/logs.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -16,14 +16,14 @@ from datetime import datetime import asyncio +from aiohttp import ClientSession, WSMessage, WSMsgType from colorama import Fore -from aiohttp import WSMsgType, WSMessage, ClientSession import click from mautrix.types import Obj -from ..config import get_token from ..base import app +from ..config import get_token history_count: int = 10 @@ -38,19 +38,13 @@ def logs(server: str, tail: int) -> None: global history_count history_count = tail loop = asyncio.get_event_loop() - future = asyncio.ensure_future(view_logs(server, token), loop=loop) - try: - loop.run_until_complete(future) - except KeyboardInterrupt: - future.cancel() - loop.run_until_complete(future) - loop.close() + loop.run_until_complete(view_logs(server, token)) def parsedate(entry: Obj) -> None: i = entry.time.index("+") i = entry.time.index(":", i) - entry.time = entry.time[:i] + entry.time[i + 1:] + entry.time = entry.time[:i] + entry.time[i + 1 :] entry.time = datetime.strptime(entry.time, "%Y-%m-%dT%H:%M:%S.%f%z") @@ -66,13 +60,16 @@ levelcolors = { def print_entry(entry: dict) -> None: entry = Obj(**entry) parsedate(entry) - print("{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}" - .format(date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), - level=entry.levelname, - levelcolor=levelcolors.get(entry.levelname, ""), - resetcolor=Fore.RESET, - logger=entry.name, - message=entry.msg)) + print( + "{levelcolor}[{date}] [{level}@{logger}] {message}{resetcolor}".format( + date=entry.time.strftime("%Y-%m-%d %H:%M:%S"), + level=entry.levelname, + levelcolor=levelcolors.get(entry.levelname, ""), + resetcolor=Fore.RESET, + logger=entry.name, + message=entry.msg, + ) + ) if entry.exc_info: print(entry.exc_info) diff --git a/maubot/cli/commands/upload.py b/maubot/cli/commands/upload.py index cb5b4b5..3c2cf1e 100644 --- a/maubot/cli/commands/upload.py +++ b/maubot/cli/commands/upload.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,45 +13,46 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from urllib.request import urlopen, Request -from urllib.error import HTTPError from typing import IO import json from colorama import Fore +from yarl import URL +import aiohttp import click -from ..base import app -from ..config import get_default_server, get_token +from ..cliq import cliq class UploadError(Exception): pass -@app.command(help="Upload a maubot plugin") +@cliq.command(help="Upload a maubot plugin") @click.argument("path") @click.option("-s", "--server", help="The maubot instance to upload the plugin to") -def upload(path: str, server: str) -> None: - server, token = get_token(server) - if not token: - return +@cliq.with_authenticated_http +async def upload(path: str, server: str, sess: aiohttp.ClientSession) -> None: with open(path, "rb") as file: - upload_file(file, server, token) + await upload_file(sess, file, server) -def upload_file(file: IO, server: str, token: str) -> None: - req = Request(f"{server}/_matrix/maubot/v1/plugins/upload?allow_override=true", data=file, - headers={"Authorization": f"Bearer {token}", "Content-Type": "application/zip"}) - try: - with urlopen(req) as resp_data: - resp = json.load(resp_data) - print(f"{Fore.GREEN}Plugin {Fore.CYAN}{resp['id']} v{resp['version']}{Fore.GREEN} " - f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}") - except HTTPError as e: - try: - err = json.load(e) - except json.JSONDecodeError: - err = {} - print(err.get("stacktrace", "")) - print(Fore.RED + "Failed to upload plugin: " + err.get("error", str(e)) + Fore.RESET) +async def upload_file(sess: aiohttp.ClientSession, file: IO, server: str) -> None: + url = (URL(server) / "_matrix/maubot/v1/plugins/upload").with_query({"allow_override": "true"}) + headers = {"Content-Type": "application/zip"} + async with sess.post(url, data=file, headers=headers) as resp: + if resp.status in (200, 201): + data = await resp.json() + print( + f"{Fore.GREEN}Plugin {Fore.CYAN}{data['id']} v{data['version']}{Fore.GREEN} " + f"uploaded to {Fore.CYAN}{server}{Fore.GREEN} successfully.{Fore.RESET}" + ) + else: + try: + err = await resp.json() + if "stacktrace" in err: + print(err["stacktrace"]) + err = err["error"] + except (aiohttp.ContentTypeError, json.JSONDecodeError, KeyError): + err = await resp.text() + print(f"{Fore.RED}Failed to upload plugin: {err}{Fore.RESET}") diff --git a/maubot/cli/config.py b/maubot/cli/config.py index 550c326..5fdc4ea 100644 --- a/maubot/cli/config.py +++ b/maubot/cli/config.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,13 +13,15 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Optional, Dict, Any +from __future__ import annotations + +from typing import Any import json import os from colorama import Fore -config: Dict[str, Any] = { +config: dict[str, Any] = { "servers": {}, "aliases": {}, "default_server": None, @@ -27,18 +29,19 @@ config: Dict[str, Any] = { configdir = os.environ.get("XDG_CONFIG_HOME", os.path.join(os.environ.get("HOME"), ".config")) -def get_default_server() -> Tuple[Optional[str], Optional[str]]: +def get_default_server() -> tuple[str | None, str | None]: try: - server: Optional[str] = config["default_server"] + server: str < None = config["default_server"] except KeyError: server = None if server is None: print(f"{Fore.RED}Default server not configured.{Fore.RESET}") + print(f"Perhaps you forgot to {Fore.CYAN}mbc login{Fore.RESET}?") return None, None return server, _get_token(server) -def get_token(server: str) -> Tuple[Optional[str], Optional[str]]: +def get_token(server: str) -> tuple[str | None, str | None]: if not server: return get_default_server() if server in config["aliases"]: @@ -46,14 +49,14 @@ def get_token(server: str) -> Tuple[Optional[str], Optional[str]]: return server, _get_token(server) -def _resolve_alias(alias: str) -> Optional[str]: +def _resolve_alias(alias: str) -> str | None: try: return config["aliases"][alias] except KeyError: return None -def _get_token(server: str) -> Optional[str]: +def _get_token(server: str) -> str | None: try: return config["servers"][server] except KeyError: diff --git a/maubot/cli/res/spdx.json.zip b/maubot/cli/res/spdx.json.zip index 4cd4701..98de1b0 100644 Binary files a/maubot/cli/res/spdx.json.zip and b/maubot/cli/res/spdx.json.zip differ diff --git a/maubot/cli/util/spdx.py b/maubot/cli/util/spdx.py index aca303d..69f58b7 100644 --- a/maubot/cli/util/spdx.py +++ b/maubot/cli/util/spdx.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,12 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Optional -import zipfile -import pkg_resources -import json +from __future__ import annotations -spdx_list: Optional[Dict[str, Dict[str, str]]] = None +import json +import zipfile + +import pkg_resources + +spdx_list: dict[str, dict[str, str]] | None = None def load() -> None: @@ -31,13 +33,13 @@ def load() -> None: spdx_list = json.load(file) -def get(id: str) -> Dict[str, str]: +def get(id: str) -> dict[str, str]: if not spdx_list: load() - return spdx_list[id.lower()] + return spdx_list[id] def valid(id: str) -> bool: if not spdx_list: load() - return id.lower() in spdx_list + return id in spdx_list diff --git a/maubot/client.py b/maubot/client.py index 9e3d1a7..b0fde73 100644 --- a/maubot/client.py +++ b/maubot/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,87 +13,141 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Iterable, Optional, Set, Callable, Any, Awaitable, Union, TYPE_CHECKING -from os import path +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable, cast +from collections import defaultdict import asyncio import logging from aiohttp import ClientSession -from yarl import URL -from mautrix.errors import MatrixInvalidToken, MatrixRequestError -from mautrix.types import (UserID, SyncToken, FilterID, ContentURI, StrippedStateEvent, Membership, - StateEvent, EventType, Filter, RoomFilter, RoomEventFilter, EventFilter, - PresenceState, StateFilter) from mautrix.client import InternalEventType -from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore +from mautrix.errors import MatrixInvalidToken +from mautrix.types import ( + ContentURI, + DeviceID, + EventFilter, + EventType, + Filter, + FilterID, + Membership, + PresenceState, + RoomEventFilter, + RoomFilter, + StateEvent, + StateFilter, + StrippedStateEvent, + SyncToken, + UserID, +) +from mautrix.util import background_task +from mautrix.util.async_getter_lock import async_getter_lock +from mautrix.util.logging import TraceLogger -from .lib.store_proxy import SyncStoreProxy -from .db import DBClient +from .db import Client as DBClient from .matrix import MaubotMatrixClient try: - from mautrix.crypto import (OlmMachine, StateStore as CryptoStateStore, CryptoStore, - PickleCryptoStore) + from mautrix.crypto import OlmMachine, PgCryptoStore - - class SQLStateStore(BaseSQLStateStore, CryptoStateStore): - pass -except ImportError: - OlmMachine = CryptoStateStore = CryptoStore = PickleCryptoStore = None - SQLStateStore = BaseSQLStateStore - -try: - from mautrix.util.async_db import Database as AsyncDatabase - from mautrix.crypto import PgCryptoStore -except ImportError: - AsyncDatabase = None - PgCryptoStore = None + crypto_import_error = None +except ImportError as e: + OlmMachine = PgCryptoStore = None + crypto_import_error = e if TYPE_CHECKING: + from .__main__ import Maubot from .instance import PluginInstance - from .config import Config - -log = logging.getLogger("maubot.client") -class Client: - log: logging.Logger = None - loop: asyncio.AbstractEventLoop = None - cache: Dict[UserID, 'Client'] = {} +class Client(DBClient): + maubot: "Maubot" = None + cache: dict[UserID, Client] = {} + _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) + log: TraceLogger = logging.getLogger("maubot.client") + http_client: ClientSession = None - global_state_store: Union['BaseSQLStateStore', 'CryptoStateStore'] = SQLStateStore() - crypto_pickle_dir: str = None - crypto_db: 'AsyncDatabase' = None - references: Set['PluginInstance'] - db_instance: DBClient + references: set[PluginInstance] client: MaubotMatrixClient - crypto: Optional['OlmMachine'] - crypto_store: Optional['CryptoStore'] + crypto: OlmMachine | None + crypto_store: PgCryptoStore | None started: bool + sync_ok: bool - remote_displayname: Optional[str] - remote_avatar_url: Optional[ContentURI] + remote_displayname: str | None + remote_avatar_url: ContentURI | None - def __init__(self, db_instance: DBClient) -> None: - self.db_instance = db_instance + def __init__( + self, + id: UserID, + homeserver: str, + access_token: str, + device_id: DeviceID, + enabled: bool = False, + next_batch: SyncToken = "", + filter_id: FilterID = "", + sync: bool = True, + autojoin: bool = True, + online: bool = True, + displayname: str = "disable", + avatar_url: str = "disable", + ) -> None: + super().__init__( + id=id, + homeserver=homeserver, + access_token=access_token, + device_id=device_id, + enabled=bool(enabled), + next_batch=next_batch, + filter_id=filter_id, + sync=bool(sync), + autojoin=bool(autojoin), + online=bool(online), + displayname=displayname, + avatar_url=avatar_url, + ) + self._postinited = False + + def __hash__(self) -> int: + return hash(self.id) + + @classmethod + def init_cls(cls, maubot: "Maubot") -> None: + cls.maubot = maubot + + def _make_client( + self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None + ) -> MaubotMatrixClient: + return MaubotMatrixClient( + mxid=self.id, + base_url=homeserver or self.homeserver, + token=token or self.access_token, + client_session=self.http_client, + log=self.log, + crypto_log=self.log.getChild("crypto"), + loop=self.maubot.loop, + device_id=device_id or self.device_id, + sync_store=self, + state_store=self.maubot.state_store, + ) + + def postinit(self) -> None: + if self._postinited: + raise RuntimeError("postinit() called twice") + self._postinited = True self.cache[self.id] = self - self.log = log.getChild(self.id) + self.log = self.log.getChild(self.id) + self.http_client = ClientSession(loop=self.maubot.loop) self.references = set() self.started = False self.sync_ok = True self.remote_displayname = None self.remote_avatar_url = None - self.client = MaubotMatrixClient(mxid=self.id, base_url=self.homeserver, - token=self.access_token, client_session=self.http_client, - log=self.log, loop=self.loop, device_id=self.device_id, - sync_store=SyncStoreProxy(self.db_instance), - state_store=self.global_state_store) - if OlmMachine and self.device_id and (self.crypto_db or self.crypto_pickle_dir): - self.crypto_store = self._make_crypto_store() - self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store) - self.client.crypto = self.crypto + self.client = self._make_client() + if self.enable_crypto: + self._prepare_crypto() else: self.crypto_store = None self.crypto = None @@ -106,21 +160,56 @@ class Client: self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False)) self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True)) - def _make_crypto_store(self) -> 'CryptoStore': - if self.crypto_db: - return PgCryptoStore(account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db) - elif self.crypto_pickle_dir: - return PickleCryptoStore(account_id=self.id, pickle_key="maubot.crypto", - path=path.join(self.crypto_pickle_dir, f"{self.id}.pickle")) - raise ValueError("Crypto database not configured") - - def _set_sync_ok(self, ok: bool) -> Callable[[Dict[str, Any]], Awaitable[None]]: - async def handler(data: Dict[str, Any]) -> None: + def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]: + async def handler(data: dict[str, Any]) -> None: self.sync_ok = ok return handler - async def start(self, try_n: Optional[int] = 0) -> None: + @property + def enable_crypto(self) -> bool: + if not self.device_id: + return False + elif not OlmMachine: + global crypto_import_error + self.log.warning( + "Client has device ID, but encryption dependencies not installed", + exc_info=crypto_import_error, + ) + # Clear the stack trace after it's logged once to avoid spamming logs + crypto_import_error = None + return False + elif not self.maubot.crypto_db: + self.log.warning("Client has device ID, but crypto database is not prepared") + return False + return True + + def _prepare_crypto(self) -> None: + self.crypto_store = PgCryptoStore( + account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db + ) + self.crypto = OlmMachine( + self.client, + self.crypto_store, + self.maubot.state_store, + log=self.client.crypto_log, + ) + self.client.crypto = self.crypto + + def _remove_crypto_event_handlers(self) -> None: + if not self.crypto: + return + handlers = [ + (InternalEventType.DEVICE_OTK_COUNT, self.crypto.handle_otk_count), + (InternalEventType.DEVICE_LISTS, self.crypto.handle_device_lists), + (EventType.TO_DEVICE_ENCRYPTED, self.crypto.handle_to_device_event), + (EventType.ROOM_KEY_REQUEST, self.crypto.handle_room_key_request), + (EventType.ROOM_MEMBER, self.crypto.handle_member_event), + ] + for event_type, func in handlers: + self.client.remove_event_handler(event_type, func) + + async def start(self, try_n: int | None = 0) -> None: try: if try_n > 0: await asyncio.sleep(try_n * 10) @@ -128,7 +217,21 @@ class Client: except Exception: self.log.exception("Failed to start") - async def _start(self, try_n: Optional[int] = 0) -> None: + async def _start_crypto(self) -> None: + self.log.debug("Enabling end-to-end encryption support") + await self.crypto_store.open() + crypto_device_id = await self.crypto_store.get_device_id() + if crypto_device_id and crypto_device_id != self.device_id: + self.log.warning( + "Mismatching device ID in crypto store and main database, resetting encryption" + ) + await self.crypto_store.delete() + crypto_device_id = None + await self.crypto.load() + if not crypto_device_id: + await self.crypto_store.put_device_id(self.device_id) + + async def _start(self, try_n: int | None = 0) -> None: if not self.enabled: self.log.debug("Not starting disabled client") return @@ -136,53 +239,60 @@ class Client: self.log.warning("Ignoring start() call to started client") return try: - user_id = await self.client.whoami() + await self.client.versions() + whoami = await self.client.whoami() except MatrixInvalidToken as e: self.log.error(f"Invalid token: {e}. Disabling client") - self.db_instance.enabled = False + self.enabled = False + await self.update() return except Exception as e: if try_n >= 8: self.log.exception("Failed to get /account/whoami, disabling client") - self.db_instance.enabled = False + self.enabled = False + await self.update() else: - self.log.warning(f"Failed to get /account/whoami, " - f"retrying in {(try_n + 1) * 10}s: {e}") - _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop) + self.log.warning( + f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}" + ) + background_task.create(self.start(try_n + 1)) return - if user_id != self.id: - self.log.error(f"User ID mismatch: expected {self.id}, but got {user_id}") - self.db_instance.enabled = False + if whoami.user_id != self.id: + self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}") + self.enabled = False + await self.update() + return + elif whoami.device_id and self.device_id and whoami.device_id != self.device_id: + self.log.error( + f"Device ID mismatch: expected {self.device_id}, but got {whoami.device_id}" + ) + self.enabled = False + await self.update() return if not self.filter_id: - self.db_instance.edit(filter_id=await self.client.create_filter(Filter( - room=RoomFilter( - timeline=RoomEventFilter( - limit=50, - lazy_load_members=True, + self.filter_id = await self.client.create_filter( + Filter( + room=RoomFilter( + timeline=RoomEventFilter( + limit=50, + lazy_load_members=True, + ), + state=StateFilter( + lazy_load_members=True, + ), ), - state=StateFilter( - lazy_load_members=True, - ) - ), - presence=EventFilter( - not_types=[EventType.PRESENCE], - ), - ))) + presence=EventFilter( + not_types=[EventType.PRESENCE], + ), + ) + ) + await self.update() if self.displayname != "disable": await self.client.set_displayname(self.displayname) if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) if self.crypto: - self.log.debug("Enabling end-to-end encryption support") - await self.crypto_store.open() - crypto_device_id = await self.crypto_store.get_device_id() - if crypto_device_id and crypto_device_id != self.device_id: - self.log.warning("Mismatching device ID in crypto store and main database. " - "Encryption may not work.") - await self.crypto.load() - if not crypto_device_id: - await self.crypto_store.put_device_id(self.device_id) + await self._start_crypto() self.start_sync() await self._update_remote_profile() self.started = True @@ -210,24 +320,22 @@ class Client: if self.crypto: await self.crypto_store.close() - def clear_cache(self) -> None: + async def clear_cache(self) -> None: self.stop_sync() - self.db_instance.edit(filter_id="", next_batch="") + self.filter_id = FilterID("") + self.next_batch = SyncToken("") + await self.update() self.start_sync() - def delete(self) -> None: - try: - del self.cache[self.id] - except KeyError: - pass - self.db_instance.delete() - def to_dict(self) -> dict: return { "id": self.id, "homeserver": self.homeserver, "access_token": self.access_token, "device_id": self.device_id, + "fingerprint": ( + self.crypto.account.fingerprint if self.crypto and self.crypto.account else None + ), "enabled": self.enabled, "started": self.started, "sync": self.sync, @@ -241,32 +349,45 @@ class Client: "instances": [instance.to_dict() for instance in self.references], } - @classmethod - def get(cls, user_id: UserID, db_instance: Optional[DBClient] = None) -> Optional['Client']: - try: - return cls.cache[user_id] - except KeyError: - db_instance = db_instance or DBClient.get(user_id) - if not db_instance: - return None - return Client(db_instance) - - @classmethod - def all(cls) -> Iterable['Client']: - return (cls.get(user.id, user) for user in DBClient.all()) - async def _handle_tombstone(self, evt: StateEvent) -> None: + if evt.state_key != "": + return if not evt.content.replacement_room: self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring") return + is_joined = await self.client.state_store.is_joined( + evt.content.replacement_room, + self.client.mxid, + ) + if is_joined: + self.log.debug( + f"Ignoring tombstone from {evt.room_id} to {evt.content.replacement_room} " + f"sent by {evt.sender}: already joined to replacement room" + ) + return + self.log.debug( + f"Following tombstone from {evt.room_id} to {evt.content.replacement_room} " + f"sent by {evt.sender}" + ) _, server = self.client.parse_user_id(evt.sender) - await self.client.join_room(evt.content.replacement_room, servers=[server]) + room_id = await self.client.join_room(evt.content.replacement_room, servers=[server]) + power_levels = await self.client.get_state_event(room_id, EventType.ROOM_POWER_LEVELS) + if power_levels.get_user_level(evt.sender) < power_levels.invite: + self.log.warning( + f"{evt.room_id} was tombstoned into {room_id} by {evt.sender}," + " but the sender doesn't have invite power levels, leaving..." + ) + await self.client.leave_room( + room_id, + f"Followed tombstone from {evt.room_id} by {evt.sender}," + " but sender doesn't have sufficient power level for invites", + ) async def _handle_invite(self, evt: StrippedStateEvent) -> None: if evt.state_key == self.id and evt.content.membership == Membership.INVITE: await self.client.join_room(evt.room_id) - async def update_started(self, started: bool) -> None: + async def update_started(self, started: bool | None) -> None: if started is None or started == self.started: return if started: @@ -274,154 +395,162 @@ class Client: else: await self.stop() - async def update_displayname(self, displayname: str) -> None: + async def update_enabled(self, enabled: bool | None, save: bool = True) -> None: + if enabled is None or enabled == self.enabled: + return + self.enabled = enabled + if save: + await self.update() + + async def update_displayname(self, displayname: str | None, save: bool = True) -> None: if displayname is None or displayname == self.displayname: return - self.db_instance.displayname = displayname + self.displayname = displayname if self.displayname != "disable": await self.client.set_displayname(self.displayname) else: await self._update_remote_profile() + if save: + await self.update() - async def update_avatar_url(self, avatar_url: ContentURI) -> None: + async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None: if avatar_url is None or avatar_url == self.avatar_url: return - self.db_instance.avatar_url = avatar_url + self.avatar_url = avatar_url if self.avatar_url != "disable": await self.client.set_avatar_url(self.avatar_url) else: await self._update_remote_profile() + if save: + await self.update() - async def update_access_details(self, access_token: str, homeserver: str) -> None: + async def update_sync(self, sync: bool | None, save: bool = True) -> None: + if sync is None or self.sync == sync: + return + self.sync = sync + if self.started: + if sync: + self.start_sync() + else: + self.stop_sync() + if save: + await self.update() + + async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None: + if autojoin is None or autojoin == self.autojoin: + return + if autojoin: + self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + else: + self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) + self.autojoin = autojoin + if save: + await self.update() + + async def update_online(self, online: bool | None, save: bool = True) -> None: + if online is None or online == self.online: + return + self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE + self.online = online + if save: + await self.update() + + async def update_access_details( + self, + access_token: str | None, + homeserver: str | None, + device_id: str | None = None, + ) -> None: if not access_token and not homeserver: return - elif access_token == self.access_token and homeserver == self.homeserver: + if device_id is None: + device_id = self.device_id + elif not device_id: + device_id = None + if ( + access_token == self.access_token + and homeserver == self.homeserver + and device_id == self.device_id + ): return - new_client = MaubotMatrixClient(mxid=self.id, base_url=homeserver or self.homeserver, - token=access_token or self.access_token, loop=self.loop, - client_session=self.http_client, device_id=self.device_id, - log=self.log, state_store=self.global_state_store) - mxid = await new_client.whoami() - if mxid != self.id: - raise ValueError(f"MXID mismatch: {mxid}") - new_client.sync_store = SyncStoreProxy(self.db_instance) + new_client = self._make_client(homeserver, access_token, device_id) + whoami = await new_client.whoami() + if whoami.user_id != self.id: + raise ValueError(f"MXID mismatch: {whoami.user_id}") + elif whoami.device_id and device_id and whoami.device_id != device_id: + raise ValueError(f"Device ID mismatch: {whoami.device_id}") + new_client.sync_store = self self.stop_sync() + + # TODO this event handler transfer is pretty hacky + self._remove_crypto_event_handlers() + self.client.crypto = None + new_client.event_handlers = self.client.event_handlers + new_client.global_event_handlers = self.client.global_event_handlers + self.client = new_client - self.db_instance.homeserver = homeserver - self.db_instance.access_token = access_token + self.homeserver = homeserver + self.access_token = access_token + self.device_id = device_id + if self.enable_crypto: + self._prepare_crypto() + await self._start_crypto() + else: + self.crypto_store = None + self.crypto = None self.start_sync() async def _update_remote_profile(self) -> None: profile = await self.client.get_profile(self.id) self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url - # region Properties + async def delete(self) -> None: + try: + del self.cache[self.id] + except KeyError: + pass + await super().delete() - @property - def id(self) -> UserID: - return self.db_instance.id + @classmethod + @async_getter_lock + async def get( + cls, + user_id: UserID, + *, + homeserver: str | None = None, + access_token: str | None = None, + device_id: DeviceID | None = None, + ) -> Client | None: + try: + return cls.cache[user_id] + except KeyError: + pass - @property - def homeserver(self) -> str: - return self.db_instance.homeserver + user = cast(cls, await super().get(user_id)) + if user is not None: + user.postinit() + return user - @property - def access_token(self) -> str: - return self.db_instance.access_token + if homeserver and access_token: + user = cls( + user_id, + homeserver=homeserver, + access_token=access_token, + device_id=device_id or "", + ) + await user.insert() + user.postinit() + return user - @property - def device_id(self) -> str: - return self.db_instance.device_id + return None - @property - def enabled(self) -> bool: - return self.db_instance.enabled - - @enabled.setter - def enabled(self, value: bool) -> None: - self.db_instance.enabled = value - - @property - def next_batch(self) -> SyncToken: - return self.db_instance.next_batch - - @property - def filter_id(self) -> FilterID: - return self.db_instance.filter_id - - @property - def sync(self) -> bool: - return self.db_instance.sync - - @sync.setter - def sync(self, value: bool) -> None: - if value == self.db_instance.sync: - return - self.db_instance.sync = value - if self.started: - if value: - self.start_sync() - else: - self.stop_sync() - - @property - def autojoin(self) -> bool: - return self.db_instance.autojoin - - @autojoin.setter - def autojoin(self, value: bool) -> None: - if value == self.db_instance.autojoin: - return - if value: - self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - else: - self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite) - self.db_instance.autojoin = value - - @property - def online(self) -> bool: - return self.db_instance.online - - @online.setter - def online(self, value: bool) -> None: - self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE - self.db_instance.online = value - - @property - def displayname(self) -> str: - return self.db_instance.displayname - - @property - def avatar_url(self) -> ContentURI: - return self.db_instance.avatar_url - - # endregion - - -def init(config: 'Config', loop: asyncio.AbstractEventLoop) -> Iterable[Client]: - Client.http_client = ClientSession(loop=loop) - Client.loop = loop - - if OlmMachine: - db_type = config["crypto_database.type"] - if db_type == "default": - db_url = config["database"] - parsed_url = URL(db_url) - if parsed_url.scheme == "sqlite": - Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] - elif parsed_url.scheme == "postgres": - if not PgCryptoStore: - log.warning("Default database is postgres, but asyncpg is not installed. " - "Encryption will not work.") - else: - Client.crypto_db = AsyncDatabase(url=db_url, - upgrade_table=PgCryptoStore.upgrade_table) - elif db_type == "pickle": - Client.crypto_pickle_dir = config["crypto_database.pickle_dir"] - elif db_type == "postgres" and PgCryptoStore: - Client.crypto_db = AsyncDatabase(url=config["crypto_database.postgres_uri"], - upgrade_table=PgCryptoStore.upgrade_table) - else: - raise ValueError("Unsupported crypto database type") - - return Client.all() + @classmethod + async def all(cls) -> AsyncGenerator[Client, None]: + users = await super().all() + user: cls + for user in users: + try: + yield cls.cache[user.id] + except KeyError: + user.postinit() + yield user diff --git a/maubot/config.py b/maubot/config.py index 34466cc..b8e42de 100644 --- a/maubot/config.py +++ b/maubot/config.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,9 +14,10 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . import random -import string -import bcrypt import re +import string + +import bcrypt from mautrix.util.config import BaseFileConfig, ConfigUpdateHelper @@ -31,36 +32,50 @@ class Config(BaseFileConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: base = helper.base copy = helper.copy - copy("database") - copy("crypto_database.type") - copy("crypto_database.postgres_uri") - copy("crypto_database.pickle_dir") + + if "database" in self and self["database"].startswith("sqlite:///"): + helper.base["database"] = self["database"].replace("sqlite:///", "sqlite:") + else: + copy("database") + copy("database_opts") + if isinstance(self["crypto_database"], dict): + if self["crypto_database.type"] == "postgres": + base["crypto_database"] = self["crypto_database.postgres_uri"] + else: + copy("crypto_database") copy("plugin_directories.upload") copy("plugin_directories.load") copy("plugin_directories.trash") - copy("plugin_directories.db") + if "plugin_directories.db" in self: + base["plugin_databases.sqlite"] = self["plugin_directories.db"] + else: + copy("plugin_databases.sqlite") + copy("plugin_databases.postgres") + copy("plugin_databases.postgres_opts") copy("server.hostname") copy("server.port") copy("server.public_url") copy("server.listen") - copy("server.base_path") copy("server.ui_base_path") copy("server.plugin_base_path") copy("server.override_resource_path") - copy("server.appservice_base_path") shared_secret = self["server.unshared_secret"] if shared_secret is None or shared_secret == "generate": base["server.unshared_secret"] = self._new_token() else: base["server.unshared_secret"] = shared_secret - copy("registration_secrets") + if "registration_secrets" in self: + base["homeservers"] = self["registration_secrets"] + else: + copy("homeservers") copy("admins") for username, password in base["admins"].items(): if password and not bcrypt_regex.match(password): if password == "password": password = self._new_token() - base["admins"][username] = bcrypt.hashpw(password.encode("utf-8"), - bcrypt.gensalt()).decode("utf-8") + base["admins"][username] = bcrypt.hashpw( + password.encode("utf-8"), bcrypt.gensalt() + ).decode("utf-8") copy("api_features.login") copy("api_features.plugin") copy("api_features.plugin_upload") diff --git a/maubot/db.py b/maubot/db.py deleted file mode 100644 index 3817882..0000000 --- a/maubot/db.py +++ /dev/null @@ -1,101 +0,0 @@ -# maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU Affero General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Affero General Public License for more details. -# -# You should have received a copy of the GNU Affero General Public License -# along with this program. If not, see . -from typing import Iterable, Optional -import logging -import sys - -from sqlalchemy import Column, String, Boolean, ForeignKey, Text -from sqlalchemy.engine.base import Engine -import sqlalchemy as sql - -from mautrix.types import UserID, FilterID, DeviceID, SyncToken, ContentURI -from mautrix.util.db import Base -from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile - -from .config import Config - - -class DBPlugin(Base): - __tablename__ = "plugin" - - id: str = Column(String(255), primary_key=True) - type: str = Column(String(255), nullable=False) - enabled: bool = Column(Boolean, nullable=False, default=False) - primary_user: UserID = Column(String(255), - ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"), - nullable=False) - config: str = Column(Text, nullable=False, default='') - - @classmethod - def all(cls) -> Iterable['DBPlugin']: - return cls._select_all() - - @classmethod - def get(cls, id: str) -> Optional['DBPlugin']: - return cls._select_one_or_none(cls.c.id == id) - - -class DBClient(Base): - __tablename__ = "client" - - id: UserID = Column(String(255), primary_key=True) - homeserver: str = Column(String(255), nullable=False) - access_token: str = Column(Text, nullable=False) - device_id: DeviceID = Column(String(255), nullable=True) - enabled: bool = Column(Boolean, nullable=False, default=False) - - next_batch: SyncToken = Column(String(255), nullable=False, default="") - filter_id: FilterID = Column(String(255), nullable=False, default="") - - sync: bool = Column(Boolean, nullable=False, default=True) - autojoin: bool = Column(Boolean, nullable=False, default=True) - online: bool = Column(Boolean, nullable=False, default=True) - - displayname: str = Column(String(255), nullable=False, default="") - avatar_url: ContentURI = Column(String(255), nullable=False, default="") - - @classmethod - def all(cls) -> Iterable['DBClient']: - return cls._select_all() - - @classmethod - def get(cls, id: str) -> Optional['DBClient']: - return cls._select_one_or_none(cls.c.id == id) - - -def init(config: Config) -> Engine: - db = sql.create_engine(config["database"]) - Base.metadata.bind = db - - for table in (DBPlugin, DBClient, RoomState, UserProfile): - table.bind(db) - - if not db.has_table("alembic_version"): - log = logging.getLogger("maubot.db") - - if db.has_table("client") and db.has_table("plugin"): - log.warning("alembic_version table not found, but client and plugin tables found. " - "Assuming pre-Alembic database and inserting version.") - db.execute("CREATE TABLE IF NOT EXISTS alembic_version (" - " version_num VARCHAR(32) PRIMARY KEY" - ");") - db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');") - else: - log.critical("alembic_version table not found. " - "Did you forget to `alembic upgrade head`?") - sys.exit(10) - - return db diff --git a/maubot/db/__init__.py b/maubot/db/__init__.py new file mode 100644 index 0000000..68833ce --- /dev/null +++ b/maubot/db/__init__.py @@ -0,0 +1,13 @@ +from mautrix.util.async_db import Database + +from .client import Client +from .instance import DatabaseEngine, Instance +from .upgrade import upgrade_table + + +def init(db: Database) -> None: + for table in (Client, Instance): + table.db = db + + +__all__ = ["upgrade_table", "init", "Client", "Instance", "DatabaseEngine"] diff --git a/maubot/db/client.py b/maubot/db/client.py new file mode 100644 index 0000000..52f3a20 --- /dev/null +++ b/maubot/db/client.py @@ -0,0 +1,114 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar + +from asyncpg import Record +from attr import dataclass + +from mautrix.client import SyncStore +from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID +from mautrix.util.async_db import Database + +fake_db = Database.create("") if TYPE_CHECKING else None + + +@dataclass +class Client(SyncStore): + db: ClassVar[Database] = fake_db + + id: UserID + homeserver: str + access_token: str + device_id: DeviceID + enabled: bool + + next_batch: SyncToken + filter_id: FilterID + + sync: bool + autojoin: bool + online: bool + + displayname: str + avatar_url: ContentURI + + @classmethod + def _from_row(cls, row: Record | None) -> Client | None: + if row is None: + return None + return cls(**row) + + _columns = ( + "id, homeserver, access_token, device_id, enabled, next_batch, filter_id, " + "sync, autojoin, online, displayname, avatar_url" + ) + + @property + def _values(self): + return ( + self.id, + self.homeserver, + self.access_token, + self.device_id, + self.enabled, + self.next_batch, + self.filter_id, + self.sync, + self.autojoin, + self.online, + self.displayname, + self.avatar_url, + ) + + @classmethod + async def all(cls) -> list[Client]: + rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client") + return [cls._from_row(row) for row in rows] + + @classmethod + async def get(cls, id: str) -> Client | None: + q = f"SELECT {cls._columns} FROM client WHERE id=$1" + return cls._from_row(await cls.db.fetchrow(q, id)) + + async def insert(self) -> None: + q = """ + INSERT INTO client ( + id, homeserver, access_token, device_id, enabled, next_batch, filter_id, + sync, autojoin, online, displayname, avatar_url + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) + """ + await self.db.execute(q, *self._values) + + async def put_next_batch(self, next_batch: SyncToken) -> None: + await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id) + self.next_batch = next_batch + + async def get_next_batch(self) -> SyncToken: + return self.next_batch + + async def update(self) -> None: + q = """ + UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5, + next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10, + displayname=$11, avatar_url=$12 + WHERE id=$1 + """ + await self.db.execute(q, *self._values) + + async def delete(self) -> None: + await self.db.execute("DELETE FROM client WHERE id=$1", self.id) diff --git a/maubot/db/instance.py b/maubot/db/instance.py new file mode 100644 index 0000000..5bb3f6a --- /dev/null +++ b/maubot/db/instance.py @@ -0,0 +1,101 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from typing import TYPE_CHECKING, ClassVar +from enum import Enum + +from asyncpg import Record +from attr import dataclass + +from mautrix.types import UserID +from mautrix.util.async_db import Database + +fake_db = Database.create("") if TYPE_CHECKING else None + + +class DatabaseEngine(Enum): + SQLITE = "sqlite" + POSTGRES = "postgres" + + +@dataclass +class Instance: + db: ClassVar[Database] = fake_db + + id: str + type: str + enabled: bool + primary_user: UserID + config_str: str + database_engine: DatabaseEngine | None + + @property + def database_engine_str(self) -> str | None: + return self.database_engine.value if self.database_engine else None + + @classmethod + def _from_row(cls, row: Record | None) -> Instance | None: + if row is None: + return None + data = {**row} + db_engine = data.pop("database_engine", None) + return cls(**data, database_engine=DatabaseEngine(db_engine) if db_engine else None) + + _columns = "id, type, enabled, primary_user, config, database_engine" + + @classmethod + async def all(cls) -> list[Instance]: + q = f"SELECT {cls._columns} FROM instance" + rows = await cls.db.fetch(q) + return [cls._from_row(row) for row in rows] + + @classmethod + async def get(cls, id: str) -> Instance | None: + q = f"SELECT {cls._columns} FROM instance WHERE id=$1" + return cls._from_row(await cls.db.fetchrow(q, id)) + + async def update_id(self, new_id: str) -> None: + await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id) + self.id = new_id + + @property + def _values(self): + return ( + self.id, + self.type, + self.enabled, + self.primary_user, + self.config_str, + self.database_engine_str, + ) + + async def insert(self) -> None: + q = ( + "INSERT INTO instance (id, type, enabled, primary_user, config, database_engine) " + "VALUES ($1, $2, $3, $4, $5, $6)" + ) + await self.db.execute(q, *self._values) + + async def update(self) -> None: + q = """ + UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5, database_engine=$6 + WHERE id=$1 + """ + await self.db.execute(q, *self._values) + + async def delete(self) -> None: + await self.db.execute("DELETE FROM instance WHERE id=$1", self.id) diff --git a/maubot/db/upgrade/__init__.py b/maubot/db/upgrade/__init__.py new file mode 100644 index 0000000..ed96422 --- /dev/null +++ b/maubot/db/upgrade/__init__.py @@ -0,0 +1,5 @@ +from mautrix.util.async_db import UpgradeTable + +upgrade_table = UpgradeTable() + +from . import v01_initial_revision, v02_instance_database_engine diff --git a/maubot/db/upgrade/v01_initial_revision.py b/maubot/db/upgrade/v01_initial_revision.py new file mode 100644 index 0000000..2da8aff --- /dev/null +++ b/maubot/db/upgrade/v01_initial_revision.py @@ -0,0 +1,136 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from mautrix.util.async_db import Connection, Scheme + +from . import upgrade_table + +legacy_version_query = "SELECT version_num FROM alembic_version" +last_legacy_version = "90aa88820eab" + + +@upgrade_table.register(description="Initial asyncpg revision") +async def upgrade_v1(conn: Connection, scheme: Scheme) -> None: + if await conn.table_exists("alembic_version"): + await migrate_legacy_to_v1(conn, scheme) + else: + return await create_v1_tables(conn) + + +async def create_v1_tables(conn: Connection) -> None: + await conn.execute( + """CREATE TABLE client ( + id TEXT PRIMARY KEY, + homeserver TEXT NOT NULL, + access_token TEXT NOT NULL, + device_id TEXT NOT NULL, + enabled BOOLEAN NOT NULL, + + next_batch TEXT NOT NULL, + filter_id TEXT NOT NULL, + + sync BOOLEAN NOT NULL, + autojoin BOOLEAN NOT NULL, + online BOOLEAN NOT NULL, + + displayname TEXT NOT NULL, + avatar_url TEXT NOT NULL + )""" + ) + await conn.execute( + """CREATE TABLE instance ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + enabled BOOLEAN NOT NULL, + primary_user TEXT NOT NULL, + config TEXT NOT NULL, + FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE + )""" + ) + + +async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None: + legacy_version = await conn.fetchval(legacy_version_query) + if legacy_version != last_legacy_version: + raise RuntimeError( + "Legacy database is not on last version. " + "Please upgrade the old database with alembic or drop it completely first." + ) + await conn.execute("ALTER TABLE plugin RENAME TO instance") + await update_state_store(conn, scheme) + if scheme != Scheme.SQLITE: + await varchar_to_text(conn) + await conn.execute("DROP TABLE alembic_version") + + +async def update_state_store(conn: Connection, scheme: Scheme) -> None: + # The Matrix state store already has more or less the correct schema, so set the version + await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)") + await conn.execute("INSERT INTO mx_version (version) VALUES (2)") + if scheme != Scheme.SQLITE: + # Remove old uppercase membership type and recreate it as lowercase + await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT") + await conn.execute("DROP TYPE IF EXISTS membership") + await conn.execute( + "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')" + ) + await conn.execute( + "ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership " + "USING LOWER(membership)::membership" + ) + else: + # Recreate table to remove CHECK constraint and lowercase everything + await conn.execute( + """CREATE TABLE new_mx_user_profile ( + room_id TEXT, + user_id TEXT, + membership TEXT NOT NULL + CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')), + displayname TEXT, + avatar_url TEXT, + PRIMARY KEY (room_id, user_id) + )""" + ) + await conn.execute( + """ + INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url) + SELECT room_id, user_id, LOWER(membership), displayname, avatar_url + FROM mx_user_profile + """ + ) + await conn.execute("DROP TABLE mx_user_profile") + await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile") + + +async def varchar_to_text(conn: Connection) -> None: + columns_to_adjust = { + "client": ( + "id", + "homeserver", + "device_id", + "next_batch", + "filter_id", + "displayname", + "avatar_url", + ), + "instance": ("id", "type", "primary_user"), + "mx_room_state": ("room_id",), + "mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"), + } + for table, columns in columns_to_adjust.items(): + for column in columns: + await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT') diff --git a/maubot/lib/store_proxy.py b/maubot/db/upgrade/v02_instance_database_engine.py similarity index 61% rename from maubot/lib/store_proxy.py rename to maubot/db/upgrade/v02_instance_database_engine.py index 6e402aa..7d2d7e7 100644 --- a/maubot/lib/store_proxy.py +++ b/maubot/db/upgrade/v02_instance_database_engine.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,16 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.client import SyncStore -from mautrix.types import SyncToken +from __future__ import annotations + +from mautrix.util.async_db import Connection + +from . import upgrade_table -class SyncStoreProxy(SyncStore): - def __init__(self, db_instance) -> None: - self.db_instance = db_instance - - async def put_next_batch(self, next_batch: SyncToken) -> None: - self.db_instance.edit(next_batch=next_batch) - - async def get_next_batch(self) -> SyncToken: - return self.db_instance.next_batch +@upgrade_table.register(description="Store instance database engine") +async def upgrade_v2(conn: Connection) -> None: + await conn.execute("ALTER TABLE instance ADD COLUMN database_engine TEXT") diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml new file mode 100644 index 0000000..0a6c8ac --- /dev/null +++ b/maubot/example-config.yaml @@ -0,0 +1,131 @@ +# The full URI to the database. SQLite and Postgres are fully supported. +# Format examples: +# SQLite: sqlite:filename.db +# Postgres: postgresql://username:password@hostname/dbname +database: sqlite:maubot.db + +# Separate database URL for the crypto database. "default" means use the same database as above. +crypto_database: default + +# Additional arguments for asyncpg.create_pool() or sqlite3.connect() +# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool +# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# For sqlite, min_size is used as the connection thread pool size and max_size is ignored. +database_opts: + min_size: 1 + max_size: 10 + +# Configuration for storing plugin .mbp files +plugin_directories: + # The directory where uploaded new plugins should be stored. + upload: ./plugins + # The directories from which plugins should be loaded. + # Duplicate plugin IDs will be moved to the trash. + load: + - ./plugins + # The directory where old plugin versions and conflicting plugins should be moved. + # Set to "delete" to delete files immediately. + trash: ./trash + +# Configuration for storing plugin databases +plugin_databases: + # The directory where SQLite plugin databases should be stored. + sqlite: ./plugins + # The connection URL for plugin databases. If null, all plugins will get SQLite databases. + # If set, plugins using the new asyncpg interface will get a Postgres connection instead. + # Plugins using the legacy SQLAlchemy interface will always get a SQLite connection. + # + # To use the same connection pool as the default database, set to "default" + # (the default database above must be postgres to do this). + # + # When enabled, maubot will create separate Postgres schemas in the database for each plugin. + # To view schemas in psql, use `\dn`. To view enter and interact with a specific schema, + # use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal + # SQL queries/psql commands. + postgres: null + # Maximum number of connections per plugin instance. + postgres_max_conns_per_plugin: 3 + # Overrides for the default database_opts when using a non-"default" postgres connection string. + postgres_opts: {} + +server: + # The IP and port to listen to. + hostname: 0.0.0.0 + port: 29316 + # Public base URL where the server is visible. + public_url: https://example.com + # The base path for the UI. + ui_base_path: /_matrix/maubot + # The base path for plugin endpoints. The instance ID will be appended directly. + plugin_base_path: /_matrix/maubot/plugin/ + # Override path from where to load UI resources. + # Set to false to using pkg_resources to find the path. + override_resource_path: false + # The shared secret to sign API access tokens. + # Set to "generate" to generate and save a new token at startup. + unshared_secret: generate + +# Known homeservers. This is required for the `mbc auth` command and also allows +# more convenient access from the management UI. This is not required to create +# clients in the management UI, since you can also just type the homeserver URL +# into the box there. +homeservers: + matrix.org: + # Client-server API URL + url: https://matrix-client.matrix.org + # registration_shared_secret from synapse config + # You can leave this empty if you don't have access to the homeserver. + # When this is empty, `mbc auth --register` won't work, but `mbc auth` (login) will. + secret: null + +# List of administrator users. Each key is a username and the value is the password. +# Plaintext passwords will be bcrypted on startup. Set empty password to prevent normal login. +# Root is a special user that can't have a password and will always exist. +admins: + root: "" + +# API feature switches. +api_features: + login: true + plugin: true + plugin_upload: true + instance: true + instance_database: true + client: true + client_proxy: true + client_auth: true + dev_open: true + log: true + +# Python logging configuration. +# +# See section 16.7.2 of the Python documentation for more info: +# https://docs.python.org/3.6/library/logging.config.html#configuration-dictionary-schema +logging: + version: 1 + formatters: + colored: + (): maubot.lib.color_log.ColorFormatter + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + normal: + format: "[%(asctime)s] [%(levelname)s@%(name)s] %(message)s" + handlers: + file: + class: logging.handlers.RotatingFileHandler + formatter: normal + filename: ./maubot.log + maxBytes: 10485760 + backupCount: 10 + console: + class: logging.StreamHandler + formatter: colored + loggers: + maubot: + level: DEBUG + mau: + level: DEBUG + aiohttp: + level: INFO + root: + level: DEBUG + handlers: [file, console] diff --git a/maubot/handlers/__init__.py b/maubot/handlers/__init__.py index 1d9da7e..e8567a2 100644 --- a/maubot/handlers/__init__.py +++ b/maubot/handlers/__init__.py @@ -1 +1 @@ -from . import event, command, web +from . import command, event, web diff --git a/maubot/handlers/command.py b/maubot/handlers/command.py index bee5509..27e6547 100644 --- a/maubot/handlers/command.py +++ b/maubot/handlers/command.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,29 +13,46 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import (Union, Callable, Sequence, Pattern, Awaitable, NewType, Optional, Any, List, - Dict, Tuple, Set, Iterable) +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + NewType, + Optional, + Pattern, + Sequence, + Set, + Tuple, + Union, +) from abc import ABC, abstractmethod import asyncio import functools import inspect import re -from mautrix.types import MessageType, EventType +from mautrix.types import EventType, MessageType from ..matrix import MaubotMessageEvent from . import event PrefixType = Optional[Union[str, Callable[[], str], Callable[[Any], str]]] -AliasesType = Union[List[str], Tuple[str, ...], Set[str], Callable[[str], bool], - Callable[[Any, str], bool]] -CommandHandlerFunc = NewType("CommandHandlerFunc", - Callable[[MaubotMessageEvent, Any], Awaitable[Any]]) -CommandHandlerDecorator = NewType("CommandHandlerDecorator", - Callable[[Union['CommandHandler', CommandHandlerFunc]], - 'CommandHandler']) -PassiveCommandHandlerDecorator = NewType("PassiveCommandHandlerDecorator", - Callable[[CommandHandlerFunc], CommandHandlerFunc]) +AliasesType = Union[ + List[str], Tuple[str, ...], Set[str], Callable[[str], bool], Callable[[Any, str], bool] +] +CommandHandlerFunc = NewType( + "CommandHandlerFunc", Callable[[MaubotMessageEvent, Any], Awaitable[Any]] +) +CommandHandlerDecorator = NewType( + "CommandHandlerDecorator", + Callable[[Union["CommandHandler", CommandHandlerFunc]], "CommandHandler"], +) +PassiveCommandHandlerDecorator = NewType( + "PassiveCommandHandlerDecorator", Callable[[CommandHandlerFunc], CommandHandlerFunc] +) def _split_in_two(val: str, split_by: str) -> List[str]: @@ -55,7 +72,7 @@ class CommandHandler: self.__mb_must_consume_args__: bool = True self.__mb_arg_fallthrough__: bool = True self.__mb_event_handler__: bool = True - self.__mb_event_type__: EventType = EventType.ROOM_MESSAGE + self.__mb_event_types__: set[EventType] = {EventType.ROOM_MESSAGE} self.__mb_msgtypes__: Iterable[MessageType] = (MessageType.TEXT,) self.__bound_copies__: Dict[Any, CommandHandler] = {} self.__bound_instance__: Any = None @@ -67,15 +84,27 @@ class CommandHandler: return self.__bound_copies__[instance] except KeyError: new_ch = type(self)(self.__mb_func__) - keys = ["parent", "subcommands", "arguments", "help", "get_name", "is_command_match", - "require_subcommand", "arg_fallthrough", "event_handler", "event_type", - "msgtypes"] + keys = [ + "parent", + "subcommands", + "arguments", + "help", + "get_name", + "is_command_match", + "require_subcommand", + "must_consume_args", + "arg_fallthrough", + "event_handler", + "event_types", + "msgtypes", + ] for key in keys: key = f"__mb_{key}__" setattr(new_ch, key, getattr(self, key)) new_ch.__bound_instance__ = instance - new_ch.__mb_subcommands__ = [subcmd.__get__(instance, instancetype) - for subcmd in self.__mb_subcommands__] + new_ch.__mb_subcommands__ = [ + subcmd.__get__(instance, instancetype) for subcmd in self.__mb_subcommands__ + ] self.__bound_copies__[instance] = new_ch return new_ch @@ -83,8 +112,13 @@ class CommandHandler: def __command_match_unset(self, val: str) -> bool: raise NotImplementedError("Hmm") - async def __call__(self, evt: MaubotMessageEvent, *, _existing_args: Dict[str, Any] = None, - remaining_val: str = None) -> Any: + async def __call__( + self, + evt: MaubotMessageEvent, + *, + _existing_args: Dict[str, Any] = None, + remaining_val: str = None, + ) -> Any: if evt.sender == evt.client.mxid or evt.content.msgtype not in self.__mb_msgtypes__: return if remaining_val is None: @@ -120,21 +154,25 @@ class CommandHandler: return await self.__mb_func__(self.__bound_instance__, evt, **call_args) return await self.__mb_func__(evt, **call_args) - async def __call_subcommand__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], - remaining_val: str) -> Tuple[bool, Any]: + async def __call_subcommand__( + self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str + ) -> Tuple[bool, Any]: command, remaining_val = _split_in_two(remaining_val.strip(), " ") for subcommand in self.__mb_subcommands__: if subcommand.__mb_is_command_match__(subcommand.__bound_instance__, command): - return True, await subcommand(evt, _existing_args=call_args, - remaining_val=remaining_val) + return True, await subcommand( + evt, _existing_args=call_args, remaining_val=remaining_val + ) return False, None - async def __parse_args__(self, evt: MaubotMessageEvent, call_args: Dict[str, Any], - remaining_val: str) -> Tuple[bool, str]: + async def __parse_args__( + self, evt: MaubotMessageEvent, call_args: Dict[str, Any], remaining_val: str + ) -> Tuple[bool, str]: for arg in self.__mb_arguments__: try: - remaining_val, call_args[arg.name] = arg.match(remaining_val.strip(), evt=evt, - instance=self.__bound_instance__) + remaining_val, call_args[arg.name] = arg.match( + remaining_val.strip(), evt=evt, instance=self.__bound_instance__ + ) if arg.required and call_args[arg.name] is None: raise ValueError("Argument required") except ArgumentSyntaxError as e: @@ -155,8 +193,9 @@ class CommandHandler: @property def __mb_usage_args__(self) -> str: - arg_usage = " ".join(f"<{arg.label}>" if arg.required else f"[{arg.label}]" - for arg in self.__mb_arguments__) + arg_usage = " ".join( + f"<{arg.label}>" if arg.required else f"[{arg.label}]" for arg in self.__mb_arguments__ + ) if self.__mb_subcommands__ and self.__mb_arg_fallthrough__: arg_usage += " " + self.__mb_usage_subcommand__ return arg_usage @@ -172,15 +211,19 @@ class CommandHandler: @property def __mb_prefix__(self) -> str: if self.__mb_parent__: - return (f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} " - f"{self.__mb_name__}") + return ( + f"!{self.__mb_parent__.__mb_get_name__(self.__bound_instance__)} " + f"{self.__mb_name__}" + ) return f"!{self.__mb_name__}" @property def __mb_usage_inline__(self) -> str: if not self.__mb_arg_fallthrough__: - return (f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" - f"* {self.__mb_name__} {self.__mb_usage_subcommand__}") + return ( + f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}\n" + f"* {self.__mb_name__} {self.__mb_usage_subcommand__}" + ) return f"* {self.__mb_name__} {self.__mb_usage_args__} - {self.__mb_help__}" @property @@ -192,8 +235,10 @@ class CommandHandler: if not self.__mb_arg_fallthrough__: if not self.__mb_arguments__: return f"**Usage:** {self.__mb_prefix__} [subcommand] [...]" - return (f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" - f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}") + return ( + f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" + f" _OR_ {self.__mb_prefix__} {self.__mb_usage_subcommand__}" + ) return f"**Usage:** {self.__mb_prefix__} {self.__mb_usage_args__}" @property @@ -202,14 +247,25 @@ class CommandHandler: return f"{self.__mb_usage_without_subcommands__} \n{self.__mb_subcommands_list__}" return self.__mb_usage_without_subcommands__ - def subcommand(self, name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, - required_subcommand: bool = True, arg_fallthrough: bool = True, - ) -> CommandHandlerDecorator: + def subcommand( + self, + name: PrefixType = None, + *, + help: str = None, + aliases: AliasesType = None, + required_subcommand: bool = True, + arg_fallthrough: bool = True, + ) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) - new(name, help=help, aliases=aliases, require_subcommand=required_subcommand, - arg_fallthrough=arg_fallthrough)(func) + new( + name, + help=help, + aliases=aliases, + require_subcommand=required_subcommand, + arg_fallthrough=arg_fallthrough, + )(func) func.__mb_parent__ = self func.__mb_event_handler__ = False self.__mb_subcommands__.append(func) @@ -218,10 +274,17 @@ class CommandHandler: return decorator -def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = None, - event_type: EventType = EventType.ROOM_MESSAGE, msgtypes: Iterable[MessageType] = None, - require_subcommand: bool = True, arg_fallthrough: bool = True, - must_consume_args: bool = True) -> CommandHandlerDecorator: +def new( + name: PrefixType = None, + *, + help: str = None, + aliases: AliasesType = None, + event_type: EventType = EventType.ROOM_MESSAGE, + msgtypes: Iterable[MessageType] = None, + require_subcommand: bool = True, + arg_fallthrough: bool = True, + must_consume_args: bool = True, +) -> CommandHandlerDecorator: def decorator(func: Union[CommandHandler, CommandHandlerFunc]) -> CommandHandler: if not isinstance(func, CommandHandler): func = CommandHandler(func) @@ -235,15 +298,16 @@ def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = Non else: func.__mb_get_name__ = lambda self: name else: - func.__mb_get_name__ = lambda self: func.__name__ + func.__mb_get_name__ = lambda self: func.__mb_func__.__name__.replace("_", "-") if callable(aliases): if len(inspect.getfullargspec(aliases).args) == 1: func.__mb_is_command_match__ = lambda self, val: aliases(val) else: func.__mb_is_command_match__ = aliases elif isinstance(aliases, (list, set, tuple)): - func.__mb_is_command_match__ = lambda self, val: (val == func.__mb_get_name__(self) - or val in aliases) + func.__mb_is_command_match__ = lambda self, val: ( + val == func.__mb_get_name__(self) or val in aliases + ) else: func.__mb_is_command_match__ = lambda self, val: val == func.__mb_get_name__(self) # Decorators are executed last to first, so we reverse the argument list. @@ -251,7 +315,7 @@ def new(name: PrefixType = None, *, help: str = None, aliases: AliasesType = Non func.__mb_require_subcommand__ = require_subcommand func.__mb_arg_fallthrough__ = arg_fallthrough func.__mb_must_consume_args__ = must_consume_args - func.__mb_event_type__ = event_type + func.__mb_event_types__ = {event_type} if msgtypes: func.__mb_msgtypes__ = msgtypes return func @@ -267,8 +331,9 @@ class ArgumentSyntaxError(ValueError): class Argument(ABC): - def __init__(self, name: str, label: str = None, *, required: bool = False, - pass_raw: bool = False) -> None: + def __init__( + self, name: str, label: str = None, *, required: bool = False, pass_raw: bool = False + ) -> None: self.name = name self.label = label or name self.required = required @@ -286,8 +351,15 @@ class Argument(ABC): class RegexArgument(Argument): - def __init__(self, name: str, label: str = None, *, required: bool = False, - pass_raw: bool = False, matches: str = None) -> None: + def __init__( + self, + name: str, + label: str = None, + *, + required: bool = False, + pass_raw: bool = False, + matches: str = None, + ) -> None: super().__init__(name, label, required=required, pass_raw=pass_raw) matches = f"^{matches}" if self.pass_raw else f"^{matches}$" self.regex = re.compile(matches) @@ -298,14 +370,23 @@ class RegexArgument(Argument): val = re.split(r"\s", val, 1)[0] match = self.regex.match(val) if match: - return (orig_val[:match.start()] + orig_val[match.end():], - match.groups() or val[match.start():match.end()]) + return ( + orig_val[: match.start()] + orig_val[match.end() :], + match.groups() or val[match.start() : match.end()], + ) return orig_val, None class CustomArgument(Argument): - def __init__(self, name: str, label: str = None, *, required: bool = False, - pass_raw: bool = False, matcher: Callable[[str], Any]) -> None: + def __init__( + self, + name: str, + label: str = None, + *, + required: bool = False, + pass_raw: bool = False, + matcher: Callable[[str], Any], + ) -> None: super().__init__(name, label, required=required, pass_raw=pass_raw) self.matcher = matcher @@ -316,7 +397,7 @@ class CustomArgument(Argument): val = re.split(r"\s", val, 1)[0] res = self.matcher(val) if res is not None: - return orig_val[len(val):], res + return orig_val[len(val) :], res return orig_val, None @@ -325,12 +406,18 @@ class SimpleArgument(Argument): if self.pass_raw: return "", val res = re.split(r"\s", val, 1)[0] - return val[len(res):], res + return val[len(res) :], res -def argument(name: str, label: str = None, *, required: bool = True, matches: Optional[str] = None, - parser: Optional[Callable[[str], Any]] = None, pass_raw: bool = False - ) -> CommandHandlerDecorator: +def argument( + name: str, + label: str = None, + *, + required: bool = True, + matches: Optional[str] = None, + parser: Optional[Callable[[str], Any]] = None, + pass_raw: bool = False, +) -> CommandHandlerDecorator: if matches: return RegexArgument(name, label, required=required, matches=matches, pass_raw=pass_raw) elif parser: @@ -339,11 +426,17 @@ def argument(name: str, label: str = None, *, required: bool = True, matches: Op return SimpleArgument(name, label, required=required, pass_raw=pass_raw) -def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (MessageType.TEXT,), - field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, - event_type: EventType = EventType.ROOM_MESSAGE, multiple: bool = False, - case_insensitive: bool = False, multiline: bool = False, dot_all: bool = False - ) -> PassiveCommandHandlerDecorator: +def passive( + regex: Union[str, Pattern], + *, + msgtypes: Sequence[MessageType] = (MessageType.TEXT,), + field: Callable[[MaubotMessageEvent], str] = lambda evt: evt.content.body, + event_type: EventType = EventType.ROOM_MESSAGE, + multiple: bool = False, + case_insensitive: bool = False, + multiline: bool = False, + dot_all: bool = False, +) -> PassiveCommandHandlerDecorator: if not isinstance(regex, Pattern): flags = re.RegexFlag.UNICODE if case_insensitive: @@ -372,12 +465,14 @@ def passive(regex: Union[str, Pattern], *, msgtypes: Sequence[MessageType] = (Me return data = field(evt) if multiple: - val = [(data[match.pos:match.endpos], *match.groups()) - for match in regex.finditer(data)] + val = [ + (data[match.pos : match.endpos], *match.groups()) + for match in regex.finditer(data) + ] else: match = regex.search(data) if match: - val = (data[match.pos:match.endpos], *match.groups()) + val = (data[match.pos : match.endpos], *match.groups()) else: val = None if val: diff --git a/maubot/handlers/event.py b/maubot/handlers/event.py index be02706..9be89b1 100644 --- a/maubot/handlers/event.py +++ b/maubot/handlers/event.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,22 +13,26 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, Union, NewType +from __future__ import annotations + +from typing import Callable, NewType -from mautrix.types import EventType from mautrix.client import EventHandler, InternalEventType +from mautrix.types import EventType EventHandlerDecorator = NewType("EventHandlerDecorator", Callable[[EventHandler], EventHandler]) -def on(var: Union[EventType, InternalEventType, EventHandler] - ) -> Union[EventHandlerDecorator, EventHandler]: +def on(var: EventType | InternalEventType | EventHandler) -> EventHandlerDecorator | EventHandler: def decorator(func: EventHandler) -> EventHandler: func.__mb_event_handler__ = True if isinstance(var, (EventType, InternalEventType)): - func.__mb_event_type__ = var + if hasattr(func, "__mb_event_types__"): + func.__mb_event_types__.add(var) + else: + func.__mb_event_types__ = {var} else: - func.__mb_event_type__ = EventType.ALL + func.__mb_event_types__ = {EventType.ALL} return func diff --git a/maubot/handlers/web.py b/maubot/handlers/web.py index cf53d68..f170124 100644 --- a/maubot/handlers/web.py +++ b/maubot/handlers/web.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,9 +13,9 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, Any, Awaitable +from typing import Any, Awaitable, Callable -from aiohttp import web, hdrs +from aiohttp import hdrs, web WebHandler = Callable[[web.Request], Awaitable[web.StreamResponse]] WebHandlerDecorator = Callable[[WebHandler], WebHandler] diff --git a/maubot/instance.py b/maubot/instance.py index 22bca17..8427e3c 100644 --- a/maubot/instance.py +++ b/maubot/instance.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,58 +13,92 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Optional, Iterable, TYPE_CHECKING -from asyncio import AbstractEventLoop -import os.path -import logging +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, AsyncGenerator, cast +from collections import defaultdict +import asyncio +import inspect import io +import logging +import os.path -from ruamel.yaml.comments import CommentedMap from ruamel.yaml import YAML -import sqlalchemy as sql +from ruamel.yaml.comments import CommentedMap -from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.types import UserID +from mautrix.util import background_task +from mautrix.util.async_db import Database, Scheme, UpgradeTable +from mautrix.util.async_getter_lock import async_getter_lock +from mautrix.util.config import BaseProxyConfig, RecursiveDict +from mautrix.util.logging import TraceLogger -from .db import DBPlugin -from .config import Config from .client import Client -from .loader import PluginLoader, ZippedPluginLoader +from .db import DatabaseEngine, Instance as DBInstance +from .lib.optionalalchemy import Engine, MetaData, create_engine +from .lib.plugin_db import ProxyPostgresDatabase +from .loader import DatabaseType, PluginLoader, ZippedPluginLoader from .plugin_base import Plugin if TYPE_CHECKING: - from .server import MaubotServer, PluginWebApp + from .__main__ import Maubot + from .server import PluginWebApp -log = logging.getLogger("maubot.instance") +log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance")) +db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db")) yaml = YAML() yaml.indent(4) yaml.width = 200 -class PluginInstance: - webserver: 'MaubotServer' = None - mb_config: Config = None - loop: AbstractEventLoop = None - cache: Dict[str, 'PluginInstance'] = {} - plugin_directories: List[str] = [] +class PluginInstance(DBInstance): + maubot: "Maubot" = None + cache: dict[str, PluginInstance] = {} + plugin_directories: list[str] = [] + _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock()) log: logging.Logger - loader: PluginLoader - client: Client - plugin: Plugin - config: BaseProxyConfig - base_cfg: Optional[RecursiveDict[CommentedMap]] - base_cfg_str: Optional[str] - inst_db: sql.engine.Engine - inst_db_tables: Dict[str, sql.Table] - inst_webapp: Optional['PluginWebApp'] - inst_webapp_url: Optional[str] + loader: PluginLoader | None + client: Client | None + plugin: Plugin | None + config: BaseProxyConfig | None + base_cfg: RecursiveDict[CommentedMap] | None + base_cfg_str: str | None + inst_db: sql.engine.Engine | Database | None + inst_db_tables: dict | None + inst_webapp: PluginWebApp | None + inst_webapp_url: str | None started: bool - def __init__(self, db_instance: DBPlugin): - self.db_instance = db_instance + def __init__( + self, + id: str, + type: str, + enabled: bool, + primary_user: UserID, + config: str = "", + database_engine: DatabaseEngine | None = None, + ) -> None: + super().__init__( + id=id, + type=type, + enabled=bool(enabled), + primary_user=primary_user, + config_str=config, + database_engine=database_engine, + ) + + def __hash__(self) -> int: + return hash(self.id) + + @classmethod + def init_cls(cls, maubot: "Maubot") -> None: + cls.maubot = maubot + + def postinit(self) -> None: self.log = log.getChild(self.id) + self.cache[self.id] = self self.config = None self.started = False self.loader = None @@ -76,7 +110,6 @@ class PluginInstance: self.inst_webapp_url = None self.base_cfg = None self.base_cfg_str = None - self.cache[self.id] = self def to_dict(self) -> dict: return { @@ -85,44 +118,144 @@ class PluginInstance: "enabled": self.enabled, "started": self.started, "primary_user": self.primary_user, - "config": self.db_instance.config, + "config": self.config_str, "base_config": self.base_cfg_str, - "database": (self.inst_db is not None - and self.mb_config["api_features.instance_database"]), + "database": ( + self.inst_db is not None and self.maubot.config["api_features.instance_database"] + ), + "database_interface": self.loader.meta.database_type_str if self.loader else "unknown", + "database_engine": self.database_engine_str, } - def get_db_tables(self) -> Dict[str, sql.Table]: - if not self.inst_db_tables: - metadata = sql.MetaData() - metadata.reflect(self.inst_db) - self.inst_db_tables = metadata.tables + def _introspect_sqlalchemy(self) -> dict: + metadata = MetaData() + metadata.reflect(self.inst_db) + return { + table.name: { + "columns": { + column.name: { + "type": str(column.type), + "unique": column.unique or False, + "default": column.default, + "nullable": column.nullable, + "primary": column.primary_key, + } + for column in table.columns + }, + } + for table in metadata.tables.values() + } + + async def _introspect_sqlite(self) -> dict: + q = """ + SELECT + m.name AS table_name, + p.cid AS col_id, + p.name AS column_name, + p.type AS data_type, + p.pk AS is_primary, + p.dflt_value AS column_default, + p.[notnull] AS is_nullable + FROM sqlite_master m + LEFT JOIN pragma_table_info((m.name)) p + WHERE m.type = 'table' + ORDER BY table_name, col_id + """ + data = await self.inst_db.fetch(q) + tables = defaultdict(lambda: {"columns": {}}) + for column in data: + table_name = column["table_name"] + col_name = column["column_name"] + tables[table_name]["columns"][col_name] = { + "type": column["data_type"], + "nullable": bool(column["is_nullable"]), + "default": column["column_default"], + "primary": bool(column["is_primary"]), + # TODO uniqueness? + } + return tables + + async def _introspect_postgres(self) -> dict: + assert isinstance(self.inst_db, ProxyPostgresDatabase) + q = """ + SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default, + tc.constraint_type + FROM information_schema.columns col + LEFT JOIN information_schema.constraint_column_usage ccu + ON ccu.column_name=col.column_name + LEFT JOIN information_schema.table_constraints tc + ON col.table_name=tc.table_name + AND col.table_schema=tc.table_schema + AND ccu.constraint_name=tc.constraint_name + AND ccu.constraint_schema=tc.constraint_schema + AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE') + WHERE col.table_schema=$1 + """ + data = await self.inst_db.fetch(q, self.inst_db.schema_name) + tables = defaultdict(lambda: {"columns": {}}) + for column in data: + table_name = column["table_name"] + col_name = column["column_name"] + tables[table_name]["columns"].setdefault( + col_name, + { + "type": column["data_type"], + "nullable": column["is_nullable"], + "default": column["column_default"], + "primary": False, + "unique": False, + }, + ) + if column["constraint_type"] == "PRIMARY KEY": + tables[table_name]["columns"][col_name]["primary"] = True + elif column["constraint_type"] == "UNIQUE": + tables[table_name]["columns"][col_name]["unique"] = True + return tables + + async def get_db_tables(self) -> dict: + if self.inst_db_tables is None: + if isinstance(self.inst_db, Engine): + self.inst_db_tables = self._introspect_sqlalchemy() + elif self.inst_db.scheme == Scheme.SQLITE: + self.inst_db_tables = await self._introspect_sqlite() + else: + self.inst_db_tables = await self._introspect_postgres() return self.inst_db_tables - def load(self) -> bool: + async def load(self) -> bool: if not self.loader: try: self.loader = PluginLoader.find(self.type) except KeyError: self.log.error(f"Failed to find loader for type {self.type}") - self.db_instance.enabled = False + await self.update_enabled(False) return False if not self.client: - self.client = Client.get(self.primary_user) + self.client = await Client.get(self.primary_user) if not self.client: self.log.error(f"Failed to get client for user {self.primary_user}") - self.db_instance.enabled = False + await self.update_enabled(False) return False - if self.loader.meta.database: - db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id) - self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") if self.loader.meta.webapp: - self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id) + self.enable_webapp() self.log.debug("Plugin instance dependencies loaded") self.loader.references.add(self) self.client.references.add(self) return True - def delete(self) -> None: + def enable_webapp(self) -> None: + self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id) + + def disable_webapp(self) -> None: + self.maubot.server.remove_instance_webapp(self.id) + self.inst_webapp = None + self.inst_webapp_url = None + + @property + def _sqlite_db_path(self) -> str: + return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db") + + async def delete(self) -> None: if self.loader is not None: self.loader.references.remove(self) if self.client is not None: @@ -131,22 +264,89 @@ class PluginInstance: del self.cache[self.id] except KeyError: pass - self.db_instance.delete() + await super().delete() if self.inst_db: - self.inst_db.dispose() - ZippedPluginLoader.trash( - os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"), - reason="deleted") + await self.stop_database() + await self.delete_database() if self.inst_webapp: - self.webserver.remove_instance_webapp(self.id) + self.disable_webapp() def load_config(self) -> CommentedMap: - return yaml.load(self.db_instance.config) + return yaml.load(self.config_str) def save_config(self, data: RecursiveDict[CommentedMap]) -> None: buf = io.StringIO() yaml.dump(data, buf) - self.db_instance.config = buf.getvalue() + val = buf.getvalue() + if val != self.config_str: + self.config_str = val + self.log.debug("Creating background task to save updated config") + background_task.create(self.update()) + + async def start_database( + self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True + ) -> None: + if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: + if self.database_engine is None: + await self.update_db_engine(DatabaseEngine.SQLITE) + elif self.database_engine == DatabaseEngine.POSTGRES: + raise RuntimeError( + "Instance database engine is marked as Postgres, but plugin uses legacy " + "database interface, which doesn't support postgres." + ) + self.inst_db = create_engine(f"sqlite:///{self._sqlite_db_path}") + elif self.loader.meta.database_type == DatabaseType.ASYNCPG: + if self.database_engine is None: + if os.path.exists(self._sqlite_db_path) or not self.maubot.plugin_postgres_db: + await self.update_db_engine(DatabaseEngine.SQLITE) + else: + await self.update_db_engine(DatabaseEngine.POSTGRES) + instance_db_log = db_log.getChild(self.id) + if self.database_engine == DatabaseEngine.POSTGRES: + if not self.maubot.plugin_postgres_db: + raise RuntimeError( + "Instance database engine is marked as Postgres, but this maubot isn't " + "configured to support Postgres for plugin databases" + ) + self.inst_db = ProxyPostgresDatabase( + pool=self.maubot.plugin_postgres_db, + instance_id=self.id, + max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"], + upgrade_table=upgrade_table, + log=instance_db_log, + ) + else: + self.inst_db = Database.create( + f"sqlite:{self._sqlite_db_path}", + upgrade_table=upgrade_table, + log=instance_db_log, + ) + if actually_start: + await self.inst_db.start() + else: + raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") + + async def stop_database(self) -> None: + if isinstance(self.inst_db, Database): + await self.inst_db.stop() + elif isinstance(self.inst_db, Engine): + self.inst_db.dispose() + else: + raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}") + + async def delete_database(self) -> None: + if self.loader.meta.database_type == DatabaseType.SQLALCHEMY: + ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") + elif self.loader.meta.database_type == DatabaseType.ASYNCPG: + if self.inst_db is None: + await self.start_database(None, actually_start=False) + if isinstance(self.inst_db, ProxyPostgresDatabase): + await self.inst_db.delete() + else: + ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted") + else: + raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}") + self.inst_db = None async def start(self) -> None: if self.started: @@ -157,9 +357,22 @@ class PluginInstance: return if not self.client or not self.loader: self.log.warning("Missing plugin instance dependencies, attempting to load...") - if not self.load(): + if not await self.load(): return cls = await self.loader.load() + if self.loader.meta.webapp and self.inst_webapp is None: + self.log.debug("Enabling webapp after plugin meta reload") + self.enable_webapp() + elif not self.loader.meta.webapp and self.inst_webapp is not None: + self.log.debug("Disabling webapp after plugin meta reload") + self.disable_webapp() + if self.loader.meta.database: + try: + await self.start_database(cls.get_db_upgrade_table()) + except Exception: + self.log.exception("Failed to start instance database") + await self.update_enabled(False) + return config_class = cls.get_config_class() if config_class: try: @@ -174,23 +387,35 @@ class PluginInstance: if self.base_cfg: base_cfg_func = self.base_cfg.clone else: + def base_cfg_func() -> None: return None + self.config = config_class(self.load_config, base_cfg_func, self.save_config) - self.plugin = cls(client=self.client.client, loop=self.loop, http=self.client.http_client, - instance_id=self.id, log=self.log, config=self.config, - database=self.inst_db, webapp=self.inst_webapp, - webapp_url=self.inst_webapp_url) + self.plugin = cls( + client=self.client.client, + loop=self.maubot.loop, + http=self.client.http_client, + instance_id=self.id, + log=self.log, + config=self.config, + database=self.inst_db, + loader=self.loader, + webapp=self.inst_webapp, + webapp_url=self.inst_webapp_url, + ) try: await self.plugin.internal_start() except Exception: self.log.exception("Failed to start instance") - self.db_instance.enabled = False + await self.update_enabled(False) return self.started = True self.inst_db_tables = None - self.log.info(f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " - f"with user {self.client.id}") + self.log.info( + f"Started instance of {self.loader.meta.id} v{self.loader.meta.version} " + f"with user {self.client.id}" + ) async def stop(self) -> None: if not self.started: @@ -203,63 +428,58 @@ class PluginInstance: except Exception: self.log.exception("Failed to stop instance") self.plugin = None + if self.inst_db: + try: + await self.stop_database() + except Exception: + self.log.exception("Failed to stop instance database") self.inst_db_tables = None - @classmethod - def get(cls, instance_id: str, db_instance: Optional[DBPlugin] = None - ) -> Optional['PluginInstance']: - try: - return cls.cache[instance_id] - except KeyError: - db_instance = db_instance or DBPlugin.get(instance_id) - if not db_instance: - return None - return PluginInstance(db_instance) + async def update_id(self, new_id: str | None) -> None: + if new_id is not None and new_id.lower() != self.id: + await super().update_id(new_id.lower()) - @classmethod - def all(cls) -> Iterable['PluginInstance']: - return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all()) - - def update_id(self, new_id: str) -> None: - if new_id is not None and new_id != self.id: - self.db_instance.id = new_id.lower() - - def update_config(self, config: str) -> None: - if not config or self.db_instance.config == config: + async def update_config(self, config: str | None) -> None: + if config is None or self.config_str == config: return - self.db_instance.config = config + self.config_str = config if self.started and self.plugin is not None: - self.plugin.on_external_config_update() + res = self.plugin.on_external_config_update() + if inspect.isawaitable(res): + await res + await self.update() - async def update_primary_user(self, primary_user: UserID) -> bool: - if not primary_user or primary_user == self.primary_user: + async def update_primary_user(self, primary_user: UserID | None) -> bool: + if primary_user is None or primary_user == self.primary_user: return True - client = Client.get(primary_user) + client = await Client.get(primary_user) if not client: return False await self.stop() - self.db_instance.primary_user = client.id + self.primary_user = client.id if self.client: self.client.references.remove(self) self.client = client self.client.references.add(self) + await self.update() await self.start() self.log.debug(f"Primary user switched to {self.client.id}") return True - async def update_type(self, type: str) -> bool: - if not type or type == self.type: + async def update_type(self, type: str | None) -> bool: + if type is None or type == self.type: return True try: loader = PluginLoader.find(type) except KeyError: return False await self.stop() - self.db_instance.type = loader.meta.id + self.type = loader.meta.id if self.loader: self.loader.references.remove(self) self.loader = loader self.loader.references.add(self) + await self.update() await self.start() self.log.debug(f"Type switched to {self.loader.meta.id}") return True @@ -268,38 +488,46 @@ class PluginInstance: if started is not None and started != self.started: await (self.start() if started else self.stop()) - def update_enabled(self, enabled: bool) -> None: + async def update_enabled(self, enabled: bool) -> None: if enabled is not None and enabled != self.enabled: - self.db_instance.enabled = enabled + self.enabled = enabled + await self.update() - # region Properties + async def update_db_engine(self, db_engine: DatabaseEngine | None) -> None: + if db_engine is not None and db_engine != self.database_engine: + self.database_engine = db_engine + await self.update() - @property - def id(self) -> str: - return self.db_instance.id + @classmethod + @async_getter_lock + async def get( + cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None + ) -> PluginInstance | None: + try: + return cls.cache[instance_id] + except KeyError: + pass - @id.setter - def id(self, value: str) -> None: - self.db_instance.id = value + instance = cast(cls, await super().get(instance_id)) + if instance is not None: + instance.postinit() + return instance - @property - def type(self) -> str: - return self.db_instance.type + if type and primary_user: + instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user) + await instance.insert() + instance.postinit() + return instance - @property - def enabled(self) -> bool: - return self.db_instance.enabled + return None - @property - def primary_user(self) -> UserID: - return self.db_instance.primary_user - - # endregion - - -def init(config: Config, webserver: 'MaubotServer', loop: AbstractEventLoop - ) -> Iterable[PluginInstance]: - PluginInstance.mb_config = config - PluginInstance.loop = loop - PluginInstance.webserver = webserver - return PluginInstance.all() + @classmethod + async def all(cls) -> AsyncGenerator[PluginInstance, None]: + instances = await super().all() + instance: PluginInstance + for instance in instances: + try: + yield cls.cache[instance.id] + except KeyError: + instance.postinit() + yield instance diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py index 284cf74..8c36ed5 100644 --- a/maubot/lib/color_log.py +++ b/maubot/lib/color_log.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,8 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from mautrix.util.logging.color import (ColorFormatter as BaseColorFormatter, PREFIX, MAU_COLOR, - MXID_COLOR, RESET) +from mautrix.util.logging.color import ( + MAU_COLOR, + MXID_COLOR, + PREFIX, + RESET, + ColorFormatter as BaseColorFormatter, +) INST_COLOR = PREFIX + "35m" # magenta LOADER_COLOR = PREFIX + "36m" # blue @@ -23,14 +28,22 @@ LOADER_COLOR = PREFIX + "36m" # blue class ColorFormatter(BaseColorFormatter): def _color_name(self, module: str) -> str: client = "maubot.client" - if module.startswith(client): - return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}" + if module.startswith(client + "."): + suffix = "" + if module.endswith(".crypto"): + suffix = f".{MAU_COLOR}crypto{RESET}" + module = module[: -len(".crypto")] + module = module[len(client) + 1 :] + return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}" instance = "maubot.instance" - if module.startswith(instance): + if module.startswith(instance + "."): return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}" + instance_db = "maubot.instance_db" + if module.startswith(instance_db + "."): + return f"{MAU_COLOR}{instance_db}{RESET}.{INST_COLOR}{module[len(instance_db) + 1:]}{RESET}" loader = "maubot.loader" - if module.startswith(loader): + if module.startswith(loader + "."): return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}" - if module.startswith("maubot"): + if module.startswith("maubot."): return f"{MAU_COLOR}{module}{RESET}" return super()._color_name(module) diff --git a/maubot/lib/future_awaitable.py b/maubot/lib/future_awaitable.py new file mode 100644 index 0000000..388eae9 --- /dev/null +++ b/maubot/lib/future_awaitable.py @@ -0,0 +1,9 @@ +from typing import Any, Awaitable, Callable, Generator + + +class FutureAwaitable: + def __init__(self, func: Callable[[], Awaitable[None]]) -> None: + self._func = func + + def __await__(self) -> Generator[Any, None, None]: + return self._func().__await__() diff --git a/maubot/lib/optionalalchemy.py b/maubot/lib/optionalalchemy.py new file mode 100644 index 0000000..ba94271 --- /dev/null +++ b/maubot/lib/optionalalchemy.py @@ -0,0 +1,19 @@ +try: + from sqlalchemy import MetaData, asc, create_engine, desc + from sqlalchemy.engine import Engine + from sqlalchemy.exc import IntegrityError, OperationalError +except ImportError: + + class FakeError(Exception): + pass + + class FakeType: + def __init__(self, *args, **kwargs): + raise Exception("SQLAlchemy is not installed") + + def create_engine(*args, **kwargs): + raise Exception("SQLAlchemy is not installed") + + MetaData = Engine = FakeType + IntegrityError = OperationalError = FakeError + asc = desc = lambda a: a diff --git a/maubot/lib/plugin_db.py b/maubot/lib/plugin_db.py new file mode 100644 index 0000000..977a619 --- /dev/null +++ b/maubot/lib/plugin_db.py @@ -0,0 +1,100 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from __future__ import annotations + +from contextlib import asynccontextmanager +import asyncio + +from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable +from mautrix.util.async_db.connection import LoggingConnection +from mautrix.util.logging import TraceLogger + +remove_double_quotes = str.maketrans({'"': "_"}) + + +class ProxyPostgresDatabase(Database): + scheme = Scheme.POSTGRES + _underlying_pool: PostgresDatabase + schema_name: str + _quoted_schema: str + _default_search_path: str + _conn_sema: asyncio.Semaphore + _max_conns: int + + def __init__( + self, + pool: PostgresDatabase, + instance_id: str, + max_conns: int, + upgrade_table: UpgradeTable | None, + log: TraceLogger | None = None, + ) -> None: + super().__init__(pool.url, upgrade_table=upgrade_table, log=log) + self._underlying_pool = pool + # Simple accidental SQL injection prevention. + # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway. + self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}" + self._quoted_schema = f'"{self.schema_name}"' + self._default_search_path = '"$user", public' + self._conn_sema = asyncio.BoundedSemaphore(max_conns) + self._max_conns = max_conns + + async def start(self) -> None: + async with self._underlying_pool.acquire() as conn: + self._default_search_path = await conn.fetchval("SHOW search_path") + self.log.trace(f"Found default search path: {self._default_search_path}") + await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}") + await super().start() + + async def stop(self) -> None: + for _ in range(self._max_conns): + try: + await asyncio.wait_for(self._conn_sema.acquire(), timeout=3) + except asyncio.TimeoutError: + self.log.warning( + "Failed to drain plugin database connection pool, " + "the plugin may be leaking database connections" + ) + break + + async def delete(self) -> None: + self.log.info(f"Deleting schema {self.schema_name} and all data in it") + try: + await self._underlying_pool.execute( + f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE" + ) + except Exception: + self.log.warning("Failed to delete schema", exc_info=True) + + @asynccontextmanager + async def acquire(self) -> LoggingConnection: + conn: LoggingConnection + async with self._conn_sema, self._underlying_pool.acquire() as conn: + await conn.execute(f"SET search_path = {self._quoted_schema}") + try: + yield conn + finally: + if not conn.wrapped.is_closed(): + try: + await conn.execute(f"SET search_path = {self._default_search_path}") + except Exception: + self.log.exception("Error resetting search_path after use") + await conn.wrapped.close() + else: + self.log.debug("Connection was closed after use, not resetting search_path") + + +__all__ = ["ProxyPostgresDatabase"] diff --git a/maubot/lib/state_store.py b/maubot/lib/state_store.py new file mode 100644 index 0000000..81fb5fd --- /dev/null +++ b/maubot/lib/state_store.py @@ -0,0 +1,27 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore + +try: + from mautrix.crypto import StateStore as CryptoStateStore + + class PgStateStore(BasePgStateStore, CryptoStateStore): + pass + +except ImportError as e: + PgStateStore = BasePgStateStore + +__all__ = ["PgStateStore"] diff --git a/maubot/lib/zipimport.py b/maubot/lib/zipimport.py index f9a0ca7..e7b77db 100644 --- a/maubot/lib/zipimport.py +++ b/maubot/lib/zipimport.py @@ -18,26 +18,28 @@ used by the builtin import mechanism for sys.path items that are paths to Zip archives. """ -from importlib import _bootstrap_external from importlib import _bootstrap # for _verbose_message -import _imp # for check_hash_based_pycs -import _io # for open +from importlib import _bootstrap_external import marshal # for loads import sys # for modules import time # for mktime -__all__ = ['ZipImportError', 'zipimporter'] +import _imp # for check_hash_based_pycs +import _io # for open + +__all__ = ["ZipImportError", "zipimporter"] def _unpack_uint32(data): """Convert 4 bytes in little-endian to an integer.""" assert len(data) == 4 - return int.from_bytes(data, 'little') + return int.from_bytes(data, "little") + def _unpack_uint16(data): """Convert 2 bytes in little-endian to an integer.""" assert len(data) == 2 - return int.from_bytes(data, 'little') + return int.from_bytes(data, "little") path_sep = _bootstrap_external.path_sep @@ -47,15 +49,17 @@ alt_path_sep = _bootstrap_external.path_separators[1:] class ZipImportError(ImportError): pass + # _read_directory() cache _zip_directory_cache = {} _module_type = type(sys) END_CENTRAL_DIR_SIZE = 22 -STRING_END_ARCHIVE = b'PK\x05\x06' +STRING_END_ARCHIVE = b"PK\x05\x06" MAX_COMMENT_LEN = (1 << 16) - 1 + class zipimporter: """zipimporter(archivepath) -> zipimporter object @@ -77,9 +81,10 @@ class zipimporter: def __init__(self, path): if not isinstance(path, str): import os + path = os.fsdecode(path) if not path: - raise ZipImportError('archive path is empty', path=path) + raise ZipImportError("archive path is empty", path=path) if alt_path_sep: path = path.replace(alt_path_sep, path_sep) @@ -92,14 +97,14 @@ class zipimporter: # Back up one path element. dirname, basename = _bootstrap_external._path_split(path) if dirname == path: - raise ZipImportError('not a Zip file', path=path) + raise ZipImportError("not a Zip file", path=path) path = dirname prefix.append(basename) else: # it exists if (st.st_mode & 0o170000) != 0o100000: # stat.S_ISREG # it's a not file - raise ZipImportError('not a Zip file', path=path) + raise ZipImportError("not a Zip file", path=path) break try: @@ -154,11 +159,10 @@ class zipimporter: # This is possibly a portion of a namespace # package. Return the string representing its path, # without a trailing separator. - return None, [f'{self.archive}{path_sep}{modpath}'] + return None, [f"{self.archive}{path_sep}{modpath}"] return None, [] - # Check whether we can satisfy the import of the module named by # 'fullname'. Return self if we can, None if we can't. def find_module(self, fullname, path=None): @@ -172,7 +176,6 @@ class zipimporter: """ return self.find_loader(fullname, path)[0] - def get_code(self, fullname): """get_code(fullname) -> code object. @@ -182,7 +185,6 @@ class zipimporter: code, ispackage, modpath = _get_module_code(self, fullname) return code - def get_data(self, pathname): """get_data(pathname) -> string with file data. @@ -194,15 +196,14 @@ class zipimporter: key = pathname if pathname.startswith(self.archive + path_sep): - key = pathname[len(self.archive + path_sep):] + key = pathname[len(self.archive + path_sep) :] try: toc_entry = self._files[key] except KeyError: - raise OSError(0, '', key) + raise OSError(0, "", key) return _get_data(self.archive, toc_entry) - # Return a string matching __file__ for the named module def get_filename(self, fullname): """get_filename(fullname) -> filename string. @@ -214,7 +215,6 @@ class zipimporter: code, ispackage, modpath = _get_module_code(self, fullname) return modpath - def get_source(self, fullname): """get_source(fullname) -> source string. @@ -228,9 +228,9 @@ class zipimporter: path = _get_module_path(self, fullname) if mi: - fullpath = _bootstrap_external._path_join(path, '__init__.py') + fullpath = _bootstrap_external._path_join(path, "__init__.py") else: - fullpath = f'{path}.py' + fullpath = f"{path}.py" try: toc_entry = self._files[fullpath] @@ -239,7 +239,6 @@ class zipimporter: return None return _get_data(self.archive, toc_entry).decode() - # Return a bool signifying whether the module is a package or not. def is_package(self, fullname): """is_package(fullname) -> bool. @@ -252,7 +251,6 @@ class zipimporter: raise ZipImportError(f"can't find module {fullname!r}", name=fullname) return mi - # Load and return the module named by 'fullname'. def load_module(self, fullname): """load_module(fullname) -> module. @@ -276,7 +274,7 @@ class zipimporter: fullpath = _bootstrap_external._path_join(self.archive, path) mod.__path__ = [fullpath] - if not hasattr(mod, '__builtins__'): + if not hasattr(mod, "__builtins__"): mod.__builtins__ = __builtins__ _bootstrap_external._fix_up_module(mod.__dict__, fullname, modpath) exec(code, mod.__dict__) @@ -287,11 +285,10 @@ class zipimporter: try: mod = sys.modules[fullname] except KeyError: - raise ImportError(f'Loaded module {fullname!r} not found in sys.modules') - _bootstrap._verbose_message('import {} # loaded from Zip {}', fullname, modpath) + raise ImportError(f"Loaded module {fullname!r} not found in sys.modules") + _bootstrap._verbose_message("import {} # loaded from Zip {}", fullname, modpath) return mod - def get_resource_reader(self, fullname): """Return the ResourceReader for a package in a zip file. @@ -305,11 +302,11 @@ class zipimporter: return None if not _ZipImportResourceReader._registered: from importlib.abc import ResourceReader + ResourceReader.register(_ZipImportResourceReader) _ZipImportResourceReader._registered = True return _ZipImportResourceReader(self, fullname) - def __repr__(self): return f'' @@ -320,16 +317,18 @@ class zipimporter: # are swapped by initzipimport() if we run in optimized mode. Also, # '/' is replaced by path_sep there. _zip_searchorder = ( - (path_sep + '__init__.pyc', True, True), - (path_sep + '__init__.py', False, True), - ('.pyc', True, False), - ('.py', False, False), + (path_sep + "__init__.pyc", True, True), + (path_sep + "__init__.py", False, True), + (".pyc", True, False), + (".py", False, False), ) + # Given a module name, return the potential file path in the # archive (without extension). def _get_module_path(self, fullname): - return self.prefix + fullname.rpartition('.')[2] + return self.prefix + fullname.rpartition(".")[2] + # Does this path represent a directory? def _is_dir(self, path): @@ -340,6 +339,7 @@ def _is_dir(self, path): # If dirpath is present in self._files, we have a directory. return dirpath in self._files + # Return some information about a module. def _get_module_info(self, fullname): path = _get_module_path(self, fullname) @@ -352,6 +352,7 @@ def _get_module_info(self, fullname): # implementation + # _read_directory(archive) -> files dict (new reference) # # Given a path to a Zip archive, build a dict, mapping file names @@ -374,7 +375,7 @@ def _get_module_info(self, fullname): # data_size and file_offset are 0. def _read_directory(archive): try: - fp = _io.open(archive, 'rb') + fp = _io.open(archive, "rb") except OSError: raise ZipImportError(f"can't open Zip file: {archive!r}", path=archive) @@ -394,36 +395,33 @@ def _read_directory(archive): fp.seek(0, 2) file_size = fp.tell() except OSError: - raise ZipImportError(f"can't read Zip file: {archive!r}", - path=archive) - max_comment_start = max(file_size - MAX_COMMENT_LEN - - END_CENTRAL_DIR_SIZE, 0) + raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) + max_comment_start = max(file_size - MAX_COMMENT_LEN - END_CENTRAL_DIR_SIZE, 0) try: fp.seek(max_comment_start) data = fp.read() except OSError: - raise ZipImportError(f"can't read Zip file: {archive!r}", - path=archive) + raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) pos = data.rfind(STRING_END_ARCHIVE) if pos < 0: - raise ZipImportError(f'not a Zip file: {archive!r}', - path=archive) - buffer = data[pos:pos+END_CENTRAL_DIR_SIZE] + raise ZipImportError(f"not a Zip file: {archive!r}", path=archive) + buffer = data[pos : pos + END_CENTRAL_DIR_SIZE] if len(buffer) != END_CENTRAL_DIR_SIZE: - raise ZipImportError(f"corrupt Zip file: {archive!r}", - path=archive) + raise ZipImportError(f"corrupt Zip file: {archive!r}", path=archive) header_position = file_size - len(data) + pos header_size = _unpack_uint32(buffer[12:16]) header_offset = _unpack_uint32(buffer[16:20]) if header_position < header_size: - raise ZipImportError(f'bad central directory size: {archive!r}', path=archive) + raise ZipImportError(f"bad central directory size: {archive!r}", path=archive) if header_position < header_offset: - raise ZipImportError(f'bad central directory offset: {archive!r}', path=archive) + raise ZipImportError(f"bad central directory offset: {archive!r}", path=archive) header_position -= header_size arc_offset = header_position - header_offset if arc_offset < 0: - raise ZipImportError(f'bad central directory size or offset: {archive!r}', path=archive) + raise ZipImportError( + f"bad central directory size or offset: {archive!r}", path=archive + ) files = {} # Start of Central Directory @@ -435,12 +433,12 @@ def _read_directory(archive): while True: buffer = fp.read(46) if len(buffer) < 4: - raise EOFError('EOF read where not expected') + raise EOFError("EOF read where not expected") # Start of file header - if buffer[:4] != b'PK\x01\x02': - break # Bad: Central Dir File Header + if buffer[:4] != b"PK\x01\x02": + break # Bad: Central Dir File Header if len(buffer) != 46: - raise EOFError('EOF read where not expected') + raise EOFError("EOF read where not expected") flags = _unpack_uint16(buffer[8:10]) compress = _unpack_uint16(buffer[10:12]) time = _unpack_uint16(buffer[12:14]) @@ -454,7 +452,7 @@ def _read_directory(archive): file_offset = _unpack_uint32(buffer[42:46]) header_size = name_size + extra_size + comment_size if file_offset > header_offset: - raise ZipImportError(f'bad local header offset: {archive!r}', path=archive) + raise ZipImportError(f"bad local header offset: {archive!r}", path=archive) file_offset += arc_offset try: @@ -478,18 +476,19 @@ def _read_directory(archive): else: # Historical ZIP filename encoding try: - name = name.decode('ascii') + name = name.decode("ascii") except UnicodeDecodeError: - name = name.decode('latin1').translate(cp437_table) + name = name.decode("latin1").translate(cp437_table) - name = name.replace('/', path_sep) + name = name.replace("/", path_sep) path = _bootstrap_external._path_join(archive, name) t = (path, compress, data_size, file_size, file_offset, time, date, crc) files[name] = t count += 1 - _bootstrap._verbose_message('zipimport: found {} names in {!r}', count, archive) + _bootstrap._verbose_message("zipimport: found {} names in {!r}", count, archive) return files + # During bootstrap, we may need to load the encodings # package from a ZIP file. But the cp437 encoding is implemented # in Python in the encodings package. @@ -498,35 +497,36 @@ def _read_directory(archive): # the cp437 encoding. cp437_table = ( # ASCII part, 8 rows x 16 chars - '\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f' - '\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f' - ' !"#$%&\'()*+,-./' - '0123456789:;<=>?' - '@ABCDEFGHIJKLMNO' - 'PQRSTUVWXYZ[\\]^_' - '`abcdefghijklmno' - 'pqrstuvwxyz{|}~\x7f' + "\x00\x01\x02\x03\x04\x05\x06\x07\x08\t\n\x0b\x0c\r\x0e\x0f" + "\x10\x11\x12\x13\x14\x15\x16\x17\x18\x19\x1a\x1b\x1c\x1d\x1e\x1f" + " !\"#$%&'()*+,-./" + "0123456789:;<=>?" + "@ABCDEFGHIJKLMNO" + "PQRSTUVWXYZ[\\]^_" + "`abcdefghijklmno" + "pqrstuvwxyz{|}~\x7f" # non-ASCII part, 16 rows x 8 chars - '\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7' - '\xea\xeb\xe8\xef\xee\xec\xc4\xc5' - '\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9' - '\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192' - '\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba' - '\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb' - '\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556' - '\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510' - '\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f' - '\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567' - '\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b' - '\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580' - '\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4' - '\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229' - '\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248' - '\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0' + "\xc7\xfc\xe9\xe2\xe4\xe0\xe5\xe7" + "\xea\xeb\xe8\xef\xee\xec\xc4\xc5" + "\xc9\xe6\xc6\xf4\xf6\xf2\xfb\xf9" + "\xff\xd6\xdc\xa2\xa3\xa5\u20a7\u0192" + "\xe1\xed\xf3\xfa\xf1\xd1\xaa\xba" + "\xbf\u2310\xac\xbd\xbc\xa1\xab\xbb" + "\u2591\u2592\u2593\u2502\u2524\u2561\u2562\u2556" + "\u2555\u2563\u2551\u2557\u255d\u255c\u255b\u2510" + "\u2514\u2534\u252c\u251c\u2500\u253c\u255e\u255f" + "\u255a\u2554\u2569\u2566\u2560\u2550\u256c\u2567" + "\u2568\u2564\u2565\u2559\u2558\u2552\u2553\u256b" + "\u256a\u2518\u250c\u2588\u2584\u258c\u2590\u2580" + "\u03b1\xdf\u0393\u03c0\u03a3\u03c3\xb5\u03c4" + "\u03a6\u0398\u03a9\u03b4\u221e\u03c6\u03b5\u2229" + "\u2261\xb1\u2265\u2264\u2320\u2321\xf7\u2248" + "\xb0\u2219\xb7\u221a\u207f\xb2\u25a0\xa0" ) _importing_zlib = False + # Return the zlib.decompress function object, or NULL if zlib couldn't # be imported. The function is cached when found, so subsequent calls # don't import zlib again. @@ -535,28 +535,29 @@ def _get_decompress_func(): if _importing_zlib: # Someone has a zlib.py[co] in their Zip file # let's avoid a stack overflow. - _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') + _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") raise ZipImportError("can't decompress data; zlib not available") _importing_zlib = True try: from zlib import decompress except Exception: - _bootstrap._verbose_message('zipimport: zlib UNAVAILABLE') + _bootstrap._verbose_message("zipimport: zlib UNAVAILABLE") raise ZipImportError("can't decompress data; zlib not available") finally: _importing_zlib = False - _bootstrap._verbose_message('zipimport: zlib available') + _bootstrap._verbose_message("zipimport: zlib available") return decompress + # Given a path to a Zip file and a toc_entry, return the (uncompressed) data. def _get_data(archive, toc_entry): datapath, compress, data_size, file_size, file_offset, time, date, crc = toc_entry if data_size < 0: - raise ZipImportError('negative data size') + raise ZipImportError("negative data size") - with _io.open(archive, 'rb') as fp: + with _io.open(archive, "rb") as fp: # Check to make sure the local file header is correct try: fp.seek(file_offset) @@ -564,11 +565,11 @@ def _get_data(archive, toc_entry): raise ZipImportError(f"can't read Zip file: {archive!r}", path=archive) buffer = fp.read(30) if len(buffer) != 30: - raise EOFError('EOF read where not expected') + raise EOFError("EOF read where not expected") - if buffer[:4] != b'PK\x03\x04': + if buffer[:4] != b"PK\x03\x04": # Bad: Local File Header - raise ZipImportError(f'bad local file header: {archive!r}', path=archive) + raise ZipImportError(f"bad local file header: {archive!r}", path=archive) name_size = _unpack_uint16(buffer[26:28]) extra_size = _unpack_uint16(buffer[28:30]) @@ -601,16 +602,17 @@ def _eq_mtime(t1, t2): # dostime only stores even seconds, so be lenient return abs(t1 - t2) <= 1 + # Given the contents of a .py[co] file, unmarshal the data # and return the code object. Return None if it the magic word doesn't # match (we do this instead of raising an exception as we fall back # to .py if available and we don't want to mask other errors). def _unmarshal_code(pathname, data, mtime): if len(data) < 16: - raise ZipImportError('bad pyc data') + raise ZipImportError("bad pyc data") if data[:4] != _bootstrap_external.MAGIC_NUMBER: - _bootstrap._verbose_message('{!r} has bad magic', pathname) + _bootstrap._verbose_message("{!r} has bad magic", pathname) return None # signal caller to try alternative flags = _unpack_uint32(data[4:8]) @@ -619,47 +621,57 @@ def _unmarshal_code(pathname, data, mtime): # pycs. We could validate hash-based pycs against the source, but it # seems likely that most people putting hash-based pycs in a zipfile # will use unchecked ones. - if (_imp.check_hash_based_pycs != 'never' and - (flags != 0x1 or _imp.check_hash_based_pycs == 'always')): + if _imp.check_hash_based_pycs != "never" and ( + flags != 0x1 or _imp.check_hash_based_pycs == "always" + ): return None elif mtime != 0 and not _eq_mtime(_unpack_uint32(data[8:12]), mtime): - _bootstrap._verbose_message('{!r} has bad mtime', pathname) + _bootstrap._verbose_message("{!r} has bad mtime", pathname) return None # signal caller to try alternative # XXX the pyc's size field is ignored; timestamp collisions are probably # unimportant with zip files. code = marshal.loads(data[16:]) if not isinstance(code, _code_type): - raise TypeError(f'compiled module {pathname!r} is not a code object') + raise TypeError(f"compiled module {pathname!r} is not a code object") return code + _code_type = type(_unmarshal_code.__code__) # Replace any occurrences of '\r\n?' in the input string with '\n'. # This converts DOS and Mac line endings to Unix line endings. def _normalize_line_endings(source): - source = source.replace(b'\r\n', b'\n') - source = source.replace(b'\r', b'\n') + source = source.replace(b"\r\n", b"\n") + source = source.replace(b"\r", b"\n") return source + # Given a string buffer containing Python source code, compile it # and return a code object. def _compile_source(pathname, source): source = _normalize_line_endings(source) - return compile(source, pathname, 'exec', dont_inherit=True) + return compile(source, pathname, "exec", dont_inherit=True) + # Convert the date/time values found in the Zip archive to a value # that's compatible with the time stamp stored in .pyc files. def _parse_dostime(d, t): - return time.mktime(( - (d >> 9) + 1980, # bits 9..15: year - (d >> 5) & 0xF, # bits 5..8: month - d & 0x1F, # bits 0..4: day - t >> 11, # bits 11..15: hours - (t >> 5) & 0x3F, # bits 8..10: minutes - (t & 0x1F) * 2, # bits 0..7: seconds / 2 - -1, -1, -1)) + return time.mktime( + ( + (d >> 9) + 1980, # bits 9..15: year + (d >> 5) & 0xF, # bits 5..8: month + d & 0x1F, # bits 0..4: day + t >> 11, # bits 11..15: hours + (t >> 5) & 0x3F, # bits 8..10: minutes + (t & 0x1F) * 2, # bits 0..7: seconds / 2 + -1, + -1, + -1, + ) + ) + # Given a path to a .pyc file in the archive, return the # modification time of the matching .py file, or 0 if no source @@ -667,7 +679,7 @@ def _parse_dostime(d, t): def _get_mtime_of_source(self, path): try: # strip 'c' or 'o' from *.py[co] - assert path[-1:] in ('c', 'o') + assert path[-1:] in ("c", "o") path = path[:-1] toc_entry = self._files[path] # fetch the time stamp of the .py file for comparison @@ -678,13 +690,14 @@ def _get_mtime_of_source(self, path): except (KeyError, IndexError, TypeError): return 0 + # Get the code object associated with the module specified by # 'fullname'. def _get_module_code(self, fullname): path = _get_module_path(self, fullname) for suffix, isbytecode, ispackage in _zip_searchorder: fullpath = path + suffix - _bootstrap._verbose_message('trying {}{}{}', self.archive, path_sep, fullpath, verbosity=2) + _bootstrap._verbose_message("trying {}{}{}", self.archive, path_sep, fullpath, verbosity=2) try: toc_entry = self._files[fullpath] except KeyError: @@ -713,6 +726,7 @@ class _ZipImportResourceReader: This class is allowed to reference all the innards and private parts of the zipimporter. """ + _registered = False def __init__(self, zipimporter, fullname): @@ -720,9 +734,10 @@ class _ZipImportResourceReader: self.fullname = fullname def open_resource(self, resource): - fullname_as_path = self.fullname.replace('.', '/') - path = f'{fullname_as_path}/{resource}' + fullname_as_path = self.fullname.replace(".", "/") + path = f"{fullname_as_path}/{resource}" from io import BytesIO + try: return BytesIO(self.zipimporter.get_data(path)) except OSError: @@ -737,8 +752,8 @@ class _ZipImportResourceReader: def is_resource(self, name): # Maybe we could do better, but if we can get the data, it's a # resource. Otherwise it isn't. - fullname_as_path = self.fullname.replace('.', '/') - path = f'{fullname_as_path}/{name}' + fullname_as_path = self.fullname.replace(".", "/") + path = f"{fullname_as_path}/{name}" try: self.zipimporter.get_data(path) except OSError: @@ -754,11 +769,12 @@ class _ZipImportResourceReader: # top of the archive, and then we iterate through _files looking for # names inside that "directory". from pathlib import Path + fullname_path = Path(self.zipimporter.get_filename(self.fullname)) relative_path = fullname_path.relative_to(self.zipimporter.archive) # Don't forget that fullname names a package, so its path will include # __init__.py, which we want to ignore. - assert relative_path.name == '__init__.py' + assert relative_path.name == "__init__.py" package_path = relative_path.parent subdirs_seen = set() for filename in self.zipimporter._files: diff --git a/maubot/loader/__init__.py b/maubot/loader/__init__.py index ec298ad..d61be5c 100644 --- a/maubot/loader/__init__.py +++ b/maubot/loader/__init__.py @@ -1,2 +1,3 @@ -from .abc import PluginLoader, PluginClass, IDConflictError, PluginMeta -from .zip import ZippedPluginLoader, MaubotZipImportError +from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader +from .meta import DatabaseType, PluginMeta +from .zip import MaubotZipImportError, ZippedPluginLoader diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py index 01713f4..c2c71b2 100644 --- a/maubot/loader/abc.py +++ b/maubot/loader/abc.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,17 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import TypeVar, Type, Dict, Set, List, TYPE_CHECKING +from __future__ import annotations + +from typing import TYPE_CHECKING, TypeVar from abc import ABC, abstractmethod import asyncio -from attr import dataclass -from packaging.version import Version, InvalidVersion - -from mautrix.types import SerializableAttrs, SerializerError, serializer, deserializer - -from ..__meta__ import __version__ from ..plugin_base import Plugin +from .meta import PluginMeta if TYPE_CHECKING: from ..instance import PluginInstance @@ -35,47 +32,40 @@ class IDConflictError(Exception): pass -@serializer(Version) -def serialize_version(version: Version) -> str: - return str(version) +class BasePluginLoader(ABC): + meta: PluginMeta + + @property + @abstractmethod + def source(self) -> str: + pass + + def sync_read_file(self, path: str) -> bytes: + raise NotImplementedError("This loader doesn't support synchronous operations") + + @abstractmethod + async def read_file(self, path: str) -> bytes: + pass + + def sync_list_files(self, directory: str) -> list[str]: + raise NotImplementedError("This loader doesn't support synchronous operations") + + @abstractmethod + async def list_files(self, directory: str) -> list[str]: + pass -@deserializer(Version) -def deserialize_version(version: str) -> Version: - try: - return Version(version) - except InvalidVersion as e: - raise SerializerError("Invalid version") from e - - -@dataclass -class PluginMeta(SerializableAttrs['PluginMeta']): - id: str - version: Version - modules: List[str] - main_class: str - - maubot: Version = Version(__version__) - database: bool = False - config: bool = False - webapp: bool = False - license: str = "" - extra_files: List[str] = [] - dependencies: List[str] = [] - soft_dependencies: List[str] = [] - - -class PluginLoader(ABC): - id_cache: Dict[str, 'PluginLoader'] = {} +class PluginLoader(BasePluginLoader, ABC): + id_cache: dict[str, PluginLoader] = {} meta: PluginMeta - references: Set['PluginInstance'] + references: set[PluginInstance] def __init__(self): self.references = set() @classmethod - def find(cls, plugin_id: str) -> 'PluginLoader': + def find(cls, plugin_id: str) -> PluginLoader: return cls.id_cache[plugin_id] def to_dict(self) -> dict: @@ -85,33 +75,22 @@ class PluginLoader(ABC): "instances": [instance.to_dict() for instance in self.references], } - @property - @abstractmethod - def source(self) -> str: - pass - - @abstractmethod - async def read_file(self, path: str) -> bytes: - pass - async def stop_instances(self) -> None: - await asyncio.gather(*[instance.stop() for instance - in self.references if instance.started]) + await asyncio.gather( + *[instance.stop() for instance in self.references if instance.started] + ) async def start_instances(self) -> None: - await asyncio.gather(*[instance.start() for instance - in self.references if instance.enabled]) + await asyncio.gather( + *[instance.start() for instance in self.references if instance.enabled] + ) @abstractmethod - async def load(self) -> Type[PluginClass]: + async def load(self) -> type[PluginClass]: pass @abstractmethod - async def reload(self) -> Type[PluginClass]: - pass - - @abstractmethod - async def unload(self) -> None: + async def reload(self) -> type[PluginClass]: pass @abstractmethod diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py new file mode 100644 index 0000000..d368e24 --- /dev/null +++ b/maubot/loader/meta.py @@ -0,0 +1,69 @@ +# maubot - A plugin-based Matrix bot system. +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import List, Optional + +from attr import dataclass +from packaging.version import InvalidVersion, Version + +from mautrix.types import ( + ExtensibleEnum, + SerializableAttrs, + SerializerError, + deserializer, + serializer, +) + +from ..__meta__ import __version__ + + +@serializer(Version) +def serialize_version(version: Version) -> str: + return str(version) + + +@deserializer(Version) +def deserialize_version(version: str) -> Version: + try: + return Version(version) + except InvalidVersion as e: + raise SerializerError("Invalid version") from e + + +class DatabaseType(ExtensibleEnum): + SQLALCHEMY = "sqlalchemy" + ASYNCPG = "asyncpg" + + +@dataclass +class PluginMeta(SerializableAttrs): + id: str + version: Version + modules: List[str] + main_class: str + + maubot: Version = Version(__version__) + database: bool = False + database_type: DatabaseType = DatabaseType.SQLALCHEMY + config: bool = False + webapp: bool = False + license: str = "" + extra_files: List[str] = [] + dependencies: List[str] = [] + soft_dependencies: List[str] = [] + + @property + def database_type_str(self) -> Optional[str]: + return self.database_type.value if self.database else None diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py index 72daf22..8642183 100644 --- a/maubot/loader/zip.py +++ b/maubot/loader/zip.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,23 +13,27 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, List, Type, Tuple, Optional -from zipfile import ZipFile, BadZipFile -from time import time -import logging -import sys -import os +from __future__ import annotations + +from time import time +from zipfile import BadZipFile, ZipFile +import logging +import os +import sys -from ruamel.yaml import YAML, YAMLError from packaging.version import Version +from ruamel.yaml import YAML, YAMLError from mautrix.types import SerializerError -from ..lib.zipimport import zipimporter, ZipImportError -from ..plugin_base import Plugin +from ..__meta__ import __version__ from ..config import Config -from .abc import PluginLoader, PluginClass, PluginMeta, IDConflictError +from ..lib.zipimport import ZipImportError, zipimporter +from ..plugin_base import Plugin +from .abc import IDConflictError, PluginClass, PluginLoader +from .meta import DatabaseType, PluginMeta +current_version = Version(__version__) yaml = YAML() @@ -50,23 +54,25 @@ class MaubotZipLoadError(MaubotZipImportError): class ZippedPluginLoader(PluginLoader): - path_cache: Dict[str, 'ZippedPluginLoader'] = {} + path_cache: dict[str, ZippedPluginLoader] = {} log: logging.Logger = logging.getLogger("maubot.loader.zip") trash_path: str = "delete" - directories: List[str] = [] + directories: list[str] = [] - path: str - meta: PluginMeta - main_class: str - main_module: str - _loaded: Type[PluginClass] - _importer: zipimporter - _file: ZipFile + path: str | None + meta: PluginMeta | None + main_class: str | None + main_module: str | None + _loaded: type[PluginClass] | None + _importer: zipimporter | None + _file: ZipFile | None def __init__(self, path: str) -> None: super().__init__() self.path = path self.meta = None + self.main_class = None + self.main_module = None self._loaded = None self._importer = None self._file = None @@ -75,7 +81,8 @@ class ZippedPluginLoader(PluginLoader): try: existing = self.id_cache[self.meta.id] raise IDConflictError( - f"Plugin with id {self.meta.id} already loaded from {existing.source}") + f"Plugin with id {self.meta.id} already loaded from {existing.source}" + ) except KeyError: pass self.path_cache[self.path] = self @@ -83,13 +90,10 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Preloaded plugin {self.meta.id} from {self.path}") def to_dict(self) -> dict: - return { - **super().to_dict(), - "path": self.path - } + return {**super().to_dict(), "path": self.path} @classmethod - def get(cls, path: str) -> 'ZippedPluginLoader': + def get(cls, path: str) -> ZippedPluginLoader: path = os.path.abspath(path) try: return cls.path_cache[path] @@ -101,16 +105,32 @@ class ZippedPluginLoader(PluginLoader): return self.path def __repr__(self) -> str: - return ("") + return ( + "" + ) - async def read_file(self, path: str) -> bytes: + def sync_read_file(self, path: str) -> bytes: return self._file.read(path) + async def read_file(self, path: str) -> bytes: + return self.sync_read_file(path) + + def sync_list_files(self, directory: str) -> list[str]: + directory = directory.rstrip("/") + return [ + file.filename + for file in self._file.filelist + if os.path.dirname(file.filename) == directory + ] + + async def list_files(self, directory: str) -> list[str]: + return self.sync_list_files(directory) + @staticmethod - def _read_meta(source) -> Tuple[ZipFile, PluginMeta]: + def _read_meta(source) -> tuple[ZipFile, PluginMeta]: try: file = ZipFile(source) data = file.read("maubot.yaml") @@ -128,12 +148,16 @@ class ZippedPluginLoader(PluginLoader): meta = PluginMeta.deserialize(meta_dict) except SerializerError as e: raise MaubotZipMetaError("Maubot plugin definition in file is invalid") from e + if meta.maubot > current_version: + raise MaubotZipMetaError( + f"Plugin requires maubot {meta.maubot}, but this instance is {current_version}" + ) return file, meta @classmethod - def verify_meta(cls, source) -> Tuple[str, Version]: + def verify_meta(cls, source) -> tuple[str, Version, DatabaseType | None]: _, meta = cls._read_meta(source) - return meta.id, meta.version + return meta.id, meta.version, meta.database_type if meta.database else None def _load_meta(self) -> None: file, meta = self._read_meta(self.path) @@ -143,7 +167,7 @@ class ZippedPluginLoader(PluginLoader): if "/" in meta.main_class: self.main_module, self.main_class = meta.main_class.split("/")[:2] else: - self.main_module = meta.modules[0] + self.main_module = meta.modules[-1] self.main_class = meta.main_class self._file = file @@ -162,24 +186,24 @@ class ZippedPluginLoader(PluginLoader): code = importer.get_code(self.main_module.replace(".", "/")) if self.main_class not in code.co_names: raise MaubotZipPreLoadError( - f"Main class {self.main_class} not in {self.main_module}") + f"Main class {self.main_class} not in {self.main_module}" + ) except ZipImportError as e: - raise MaubotZipPreLoadError( - f"Main module {self.main_module} not found in file") from e + raise MaubotZipPreLoadError(f"Main module {self.main_module} not found in file") from e for module in self.meta.modules: try: importer.find_module(module) except ZipImportError as e: raise MaubotZipPreLoadError(f"Module {module} not found in file") from e - async def load(self, reset_cache: bool = False) -> Type[PluginClass]: + async def load(self, reset_cache: bool = False) -> type[PluginClass]: try: return self._load(reset_cache) except MaubotZipImportError: self.log.exception(f"Failed to load {self.meta.id} v{self.meta.version}") raise - def _load(self, reset_cache: bool = False) -> Type[PluginClass]: + def _load(self, reset_cache: bool = False) -> type[PluginClass]: if self._loaded is not None and not reset_cache: return self._loaded self._load_meta() @@ -208,13 +232,18 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Loaded and imported plugin {self.meta.id} from {self.path}") return plugin - async def reload(self, new_path: Optional[str] = None) -> Type[PluginClass]: - await self.unload() - if new_path is not None: + async def reload(self, new_path: str | None = None) -> type[PluginClass]: + self._unload() + if new_path is not None and new_path != self.path: + try: + del self.path_cache[self.path] + except KeyError: + pass self.path = new_path + self.path_cache[self.path] = self return await self.load(reset_cache=True) - async def unload(self) -> None: + def _unload(self) -> None: for name, mod in list(sys.modules.items()): if (getattr(mod, "__file__", "") or "").startswith(self.path): del sys.modules[name] @@ -222,7 +251,7 @@ class ZippedPluginLoader(PluginLoader): self.log.debug(f"Unloaded plugin {self.meta.id} at {self.path}") async def delete(self) -> None: - await self.unload() + self._unload() try: del self.path_cache[self.path] except KeyError: @@ -240,12 +269,22 @@ class ZippedPluginLoader(PluginLoader): self.path = None @classmethod - def trash(cls, file_path: str, new_name: Optional[str] = None, reason: str = "error") -> None: + def trash(cls, file_path: str, new_name: str | None = None, reason: str = "error") -> None: if cls.trash_path == "delete": - os.remove(file_path) + try: + os.remove(file_path) + except FileNotFoundError: + pass else: new_name = new_name or f"{int(time())}-{reason}-{os.path.basename(file_path)}" - os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name))) + try: + os.rename(file_path, os.path.abspath(os.path.join(cls.trash_path, new_name))) + except OSError as e: + cls.log.warning(f"Failed to rename {file_path}: {e} - trying to delete") + try: + os.remove(file_path) + except FileNotFoundError: + pass @classmethod def load_all(cls): diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py index 5326039..c2e5f24 100644 --- a/maubot/management/api/__init__.py +++ b/maubot/management/api/__init__.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,13 +13,14 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web from asyncio import AbstractEventLoop import importlib +from aiohttp import web + from ...config import Config -from .base import routes, get_config, set_config, set_loop from .auth import check_token +from .base import get_config, routes, set_config from .middleware import auth, error @@ -30,14 +31,15 @@ def features(request: web.Request) -> web.Response: if err is None: return web.json_response(data) else: - return web.json_response({ - "login": data["login"], - }) + return web.json_response( + { + "login": data["login"], + } + ) def init(cfg: Config, loop: AbstractEventLoop) -> web.Application: set_config(cfg) - set_loop(loop) for pkg, enabled in cfg["api_features"].items(): if enabled: importlib.import_module(f"maubot.management.api.{pkg}") diff --git a/maubot/management/api/auth.py b/maubot/management/api/auth.py index 4675301..0abc3ad 100644 --- a/maubot/management/api/auth.py +++ b/maubot/management/api/auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional +from __future__ import annotations + from time import time from aiohttp import web @@ -21,7 +22,7 @@ from aiohttp import web from mautrix.types import UserID from mautrix.util.signed_token import sign_token, verify_token -from .base import routes, get_config +from .base import get_config, routes from .responses import resp @@ -33,22 +34,25 @@ def is_valid_token(token: str) -> bool: def create_token(user: UserID) -> str: - return sign_token(get_config()["server.unshared_secret"], { - "user_id": user, - "created_at": int(time()), - }) + return sign_token( + get_config()["server.unshared_secret"], + { + "user_id": user, + "created_at": int(time()), + }, + ) def get_token(request: web.Request) -> str: token = request.headers.get("Authorization", "") if not token or not token.startswith("Bearer "): - token = request.query.get("access_token", None) + token = request.query.get("access_token", "") else: - token = token[len("Bearer "):] + token = token[len("Bearer ") :] return token -def check_token(request: web.Request) -> Optional[web.Response]: +def check_token(request: web.Request) -> web.Response | None: token = get_token(request) if not token: return resp.no_token diff --git a/maubot/management/api/base.py b/maubot/management/api/base.py index b6a5dea..3d7693a 100644 --- a/maubot/management/api/base.py +++ b/maubot/management/api/base.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,15 +13,17 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web +from __future__ import annotations + import asyncio +from aiohttp import web + from ...__meta__ import __version__ from ...config import Config routes: web.RouteTableDef = web.RouteTableDef() -_config: Config = None -_loop: asyncio.AbstractEventLoop = None +_config: Config | None = None def set_config(config: Config) -> None: @@ -33,17 +35,6 @@ def get_config() -> Config: return _config -def set_loop(loop: asyncio.AbstractEventLoop) -> None: - global _loop - _loop = loop - - -def get_loop() -> asyncio.AbstractEventLoop: - return _loop - - @routes.get("/version") async def version(_: web.Request) -> web.Response: - return web.json_response({ - "version": __version__ - }) + return web.json_response({"version": __version__}) diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py index 0585d63..d2ad35d 100644 --- a/maubot/management/api/client.py +++ b/maubot/management/api/client.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,20 +13,23 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional +from __future__ import annotations + from json import JSONDecodeError +import logging from aiohttp import web -from mautrix.types import UserID, SyncToken, FilterID -from mautrix.errors import MatrixRequestError, MatrixConnectionError, MatrixInvalidToken from mautrix.client import Client as MatrixClient +from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequestError +from mautrix.types import FilterID, SyncToken, UserID -from ...db import DBClient from ...client import Client from .base import routes from .responses import resp +log = logging.getLogger("maubot.server.client") + @routes.get("/clients") async def get_clients(_: web.Request) -> web.Response: @@ -36,64 +39,94 @@ async def get_clients(_: web.Request) -> web.Response: @routes.get("/client/{id}") async def get_client(request: web.Request) -> web.Response: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + client = await Client.get(user_id) if not client: return resp.client_not_found return resp.found(client.to_dict()) -async def _create_client(user_id: Optional[UserID], data: dict) -> web.Response: +async def _create_client(user_id: UserID | None, data: dict) -> web.Response: homeserver = data.get("homeserver", None) access_token = data.get("access_token", None) - new_client = MatrixClient(mxid="@not:a.mxid", base_url=homeserver, token=access_token, - loop=Client.loop, client_session=Client.http_client) + device_id = data.get("device_id", None) + new_client = MatrixClient( + mxid="@not:a.mxid", + base_url=homeserver, + token=access_token, + client_session=Client.http_client, + ) try: - mxid = await new_client.whoami() - except MatrixInvalidToken: + whoami = await new_client.whoami() + except MatrixInvalidToken as e: return resp.bad_client_access_token except MatrixRequestError: + log.warning(f"Failed to get whoami from {homeserver} for new client", exc_info=True) return resp.bad_client_access_details except MatrixConnectionError: + log.warning(f"Failed to connect to {homeserver} for new client", exc_info=True) return resp.bad_client_connection_details if user_id is None: - existing_client = Client.get(mxid, None) + existing_client = await Client.get(whoami.user_id) if existing_client is not None: return resp.user_exists - elif mxid != user_id: - return resp.mxid_mismatch(mxid) - db_instance = DBClient(id=mxid, homeserver=homeserver, access_token=access_token, - enabled=data.get("enabled", True), next_batch=SyncToken(""), - filter_id=FilterID(""), sync=data.get("sync", True), - autojoin=data.get("autojoin", True), online=data.get("online", True), - displayname=data.get("displayname", ""), - avatar_url=data.get("avatar_url", "")) - client = Client(db_instance) - client.db_instance.insert() + elif whoami.user_id != user_id: + return resp.mxid_mismatch(whoami.user_id) + elif whoami.device_id and device_id and whoami.device_id != device_id: + return resp.device_id_mismatch(whoami.device_id) + client = await Client.get( + whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id + ) + client.enabled = data.get("enabled", True) + client.sync = data.get("sync", True) + await client.update_autojoin(data.get("autojoin", True), save=False) + await client.update_online(data.get("online", True), save=False) + client.displayname = data.get("displayname", "disable") + client.avatar_url = data.get("avatar_url", "disable") + await client.update() await client.start() return resp.created(client.to_dict()) -async def _update_client(client: Client, data: dict) -> web.Response: +async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response: try: - await client.update_access_details(data.get("access_token", None), - data.get("homeserver", None)) + await client.update_access_details( + data.get("access_token"), data.get("homeserver"), data.get("device_id") + ) except MatrixInvalidToken: return resp.bad_client_access_token except MatrixRequestError: + log.warning( + f"Failed to get whoami from homeserver to update client details", exc_info=True + ) return resp.bad_client_access_details except MatrixConnectionError: + log.warning(f"Failed to connect to homeserver to update client details", exc_info=True) return resp.bad_client_connection_details except ValueError as e: - return resp.mxid_mismatch(str(e)[len("MXID mismatch: "):]) - with client.db_instance.edit_mode(): - await client.update_avatar_url(data.get("avatar_url", None)) - await client.update_displayname(data.get("displayname", None)) - await client.update_started(data.get("started", None)) - client.enabled = data.get("enabled", client.enabled) - client.autojoin = data.get("autojoin", client.autojoin) - client.online = data.get("online", client.online) - client.sync = data.get("sync", client.sync) - return resp.updated(client.to_dict()) + str_err = str(e) + if str_err.startswith("MXID mismatch"): + return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :]) + elif str_err.startswith("Device ID mismatch"): + return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :]) + await client.update_avatar_url(data.get("avatar_url"), save=False) + await client.update_displayname(data.get("displayname"), save=False) + await client.update_started(data.get("started")) + await client.update_enabled(data.get("enabled"), save=False) + await client.update_autojoin(data.get("autojoin"), save=False) + await client.update_online(data.get("online"), save=False) + await client.update_sync(data.get("sync"), save=False) + await client.update() + return resp.updated(client.to_dict(), is_login=is_login) + + +async def _create_or_update_client( + user_id: UserID, data: dict, is_login: bool = False +) -> web.Response: + client = await Client.get(user_id) + if not client: + return await _create_client(user_id, data) + else: + return await _update_client(client, data, is_login=is_login) @routes.post("/client/new") @@ -107,37 +140,33 @@ async def create_client(request: web.Request) -> web.Response: @routes.put("/client/{id}") async def update_client(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + user_id = request.match_info["id"] try: data = await request.json() except JSONDecodeError: return resp.body_not_json - if not client: - return await _create_client(user_id, data) - else: - return await _update_client(client, data) + return await _create_or_update_client(user_id, data) @routes.delete("/client/{id}") async def delete_client(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + user_id = request.match_info["id"] + client = await Client.get(user_id) if not client: return resp.client_not_found if len(client.references) > 0: return resp.client_in_use if client.started: await client.stop() - client.delete() + await client.delete() return resp.deleted @routes.post("/client/{id}/clearcache") async def clear_client_cache(request: web.Request) -> web.Response: - user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + user_id = request.match_info["id"] + client = await Client.get(user_id) if not client: return resp.client_not_found - client.clear_cache() + await client.clear_cache() return resp.ok diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py index 8a007f0..4e5e201 100644 --- a/maubot/management/api/client_auth.py +++ b/maubot/management/api/client_auth.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,31 +13,105 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Tuple, NamedTuple, Optional -from json import JSONDecodeError +from __future__ import annotations + +from typing import NamedTuple from http import HTTPStatus +from json import JSONDecodeError +import asyncio import hashlib +import hmac import random import string -import hmac from aiohttp import web -from mautrix.api import HTTPAPI, Path, Method -from mautrix.errors import MatrixRequestError +from yarl import URL -from .base import routes, get_config, get_loop +from mautrix.api import Method, Path, SynapseAdminPath +from mautrix.client import ClientAPI +from mautrix.errors import MatrixRequestError +from mautrix.types import LoginResponse, LoginType + +from .base import get_config, routes +from .client import _create_client, _create_or_update_client from .responses import resp -def registration_secrets() -> Dict[str, Dict[str, str]]: - return get_config()["registration_secrets"] +def known_homeservers() -> dict[str, dict[str, str]]: + return get_config()["homeservers"] -def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool = False, user_type: str = None): +@routes.get("/client/auth/servers") +async def get_known_servers(_: web.Request) -> web.Response: + return web.json_response({key: value["url"] for key, value in known_homeservers().items()}) + + +class AuthRequestInfo(NamedTuple): + server_name: str + client: ClientAPI + secret: str + username: str + password: str + user_type: str + device_name: str + update_client: bool + sso: bool + + +truthy_strings = ("1", "true", "yes") + + +async def read_client_auth_request( + request: web.Request, +) -> tuple[AuthRequestInfo | None, web.Response | None]: + server_name = request.match_info.get("server", None) + server = known_homeservers().get(server_name, None) + if not server: + return None, resp.server_not_found + try: + body = await request.json() + except JSONDecodeError: + return None, resp.body_not_json + sso = request.query.get("sso", "").lower() in truthy_strings + try: + username = body["username"] + password = body["password"] + except KeyError: + if not sso: + return None, resp.username_or_password_missing + username = password = None + try: + base_url = server["url"] + except KeyError: + return None, resp.invalid_server + return ( + AuthRequestInfo( + server_name=server_name, + client=ClientAPI(base_url=base_url), + secret=server.get("secret"), + username=username, + password=password, + user_type=body.get("user_type", "bot"), + device_name=body.get("device_name", "Maubot"), + update_client=request.query.get("update_client", "").lower() in truthy_strings, + sso=sso, + ), + None, + ) + + +def generate_mac( + secret: str, + nonce: str, + username: str, + password: str, + admin: bool = False, + user_type: str = None, +) -> str: mac = hmac.new(key=secret.encode("utf-8"), digestmod=hashlib.sha1) mac.update(nonce.encode("utf-8")) mac.update(b"\x00") - mac.update(user.encode("utf-8")) + mac.update(username.encode("utf-8")) mac.update(b"\x00") mac.update(password.encode("utf-8")) mac.update(b"\x00") @@ -48,86 +122,152 @@ def generate_mac(secret: str, nonce: str, user: str, password: str, admin: bool return mac.hexdigest() -@routes.get("/client/auth/servers") -async def get_registerable_servers(_: web.Request) -> web.Response: - return web.json_response({key: value["url"] for key, value in registration_secrets().items()}) - - -AuthRequestInfo = NamedTuple("AuthRequestInfo", api=HTTPAPI, secret=str, username=str, - password=str, user_type=str) - - -async def read_client_auth_request(request: web.Request) -> Tuple[Optional[AuthRequestInfo], - Optional[web.Response]]: - server_name = request.match_info.get("server", None) - server = registration_secrets().get(server_name, None) - if not server: - return None, resp.server_not_found - try: - body = await request.json() - except JSONDecodeError: - return None, resp.body_not_json - try: - username = body["username"] - password = body["password"] - except KeyError: - return None, resp.username_or_password_missing - try: - base_url = server["url"] - secret = server["secret"] - except KeyError: - return None, resp.invalid_server - api = HTTPAPI(base_url, "", loop=get_loop()) - user_type = body.get("user_type", "bot") - return AuthRequestInfo(api, secret, username, password, user_type), None - - @routes.post("/client/auth/{server}/register") async def register(request: web.Request) -> web.Response: - info, err = await read_client_auth_request(request) + req, err = await read_client_auth_request(request) if err is not None: return err - api, secret, username, password, user_type = info - res = await api.request(Method.GET, Path.admin.register) - nonce = res["nonce"] - mac = generate_mac(secret, nonce, username, password, user_type=user_type) + if req.sso: + return resp.registration_no_sso + elif not req.secret: + return resp.registration_secret_not_found + path = SynapseAdminPath.v1.register + res = await req.client.api.request(Method.GET, path) + content = { + "nonce": res["nonce"], + "username": req.username, + "password": req.password, + "admin": False, + "user_type": req.user_type, + } + content["mac"] = generate_mac(**content, secret=req.secret) try: - return web.json_response(await api.request(Method.POST, Path.admin.register, content={ - "nonce": nonce, - "username": username, - "password": password, - "admin": False, - "mac": mac, - # Older versions of synapse will ignore this field if it is None - "user_type": user_type, - })) + raw_res = await req.client.api.request(Method.POST, path, content=content) except MatrixRequestError as e: - return web.json_response({ - "errcode": e.errcode, - "error": e.message, - "http_status": e.http_status, - }, status=HTTPStatus.INTERNAL_SERVER_ERROR) + return web.json_response( + { + "errcode": e.errcode, + "error": e.message, + "http_status": e.http_status, + }, + status=HTTPStatus.INTERNAL_SERVER_ERROR, + ) + login_res = LoginResponse.deserialize(raw_res) + if req.update_client: + return await _create_client( + login_res.user_id, + { + "homeserver": str(req.client.api.base_url), + "access_token": login_res.access_token, + "device_id": login_res.device_id, + }, + ) + return web.json_response(login_res.serialize()) @routes.post("/client/auth/{server}/login") async def login(request: web.Request) -> web.Response: - info, err = await read_client_auth_request(request) + req, err = await read_client_auth_request(request) if err is not None: return err - api, _, username, password, _ = info - device_id = ''.join(random.choices(string.ascii_uppercase + string.digits, k=8)) + if req.sso: + return await _do_sso(req) + else: + return await _do_login(req) + + +async def _do_sso(req: AuthRequestInfo) -> web.Response: + flows = await req.client.get_login_flows() + if not flows.supports_type(LoginType.SSO): + return resp.sso_not_supported + waiter_id = "".join(random.choices(string.ascii_lowercase + string.digits, k=16)) + cfg = get_config() + public_url = ( + URL(cfg["server.public_url"]) + / "_matrix/maubot/v1/client/auth_external_sso/complete" + / waiter_id + ) + sso_url = req.client.api.base_url.with_path(str(Path.v3.login.sso.redirect)).with_query( + {"redirectUrl": str(public_url)} + ) + sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future() + return web.json_response({"sso_url": str(sso_url), "id": waiter_id}) + + +async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response: + device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8)) + device_id = f"maubot_{device_id}" try: - return web.json_response(await api.request(Method.POST, Path.login, content={ - "type": "m.login.password", - "identifier": { - "type": "m.id.user", - "user": username, - }, - "password": password, - "device_id": f"maubot_{device_id}", - })) + if req.sso: + res = await req.client.login( + token=login_token, + login_type=LoginType.TOKEN, + device_id=device_id, + store_access_token=False, + initial_device_display_name=req.device_name, + ) + else: + res = await req.client.login( + identifier=req.username, + login_type=LoginType.PASSWORD, + password=req.password, + device_id=device_id, + initial_device_display_name=req.device_name, + store_access_token=False, + ) except MatrixRequestError as e: - return web.json_response({ - "errcode": e.errcode, - "error": e.message, - }, status=e.http_status) + return web.json_response( + { + "errcode": e.errcode, + "error": e.message, + }, + status=e.http_status, + ) + if req.update_client: + return await _create_or_update_client( + res.user_id, + { + "homeserver": str(req.client.api.base_url), + "access_token": res.access_token, + "device_id": res.device_id, + }, + is_login=True, + ) + return web.json_response(res.serialize()) + + +sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {} + + +@routes.post("/client/auth/{server}/sso/{id}/wait") +async def wait_sso(request: web.Request) -> web.Response: + waiter_id = request.match_info["id"] + req, fut = sso_waiters[waiter_id] + try: + login_token = await fut + finally: + sso_waiters.pop(waiter_id, None) + return await _do_login(req, login_token) + + +@routes.get("/client/auth_external_sso/complete/{id}") +async def complete_sso(request: web.Request) -> web.Response: + try: + _, fut = sso_waiters[request.match_info["id"]] + except KeyError: + return web.Response(status=404, text="Invalid session ID\n") + if fut.cancelled(): + return web.Response(status=200, text="The login was cancelled from the Maubot client\n") + elif fut.done(): + return web.Response(status=200, text="The login token was already received\n") + try: + fut.set_result(request.query["loginToken"]) + except KeyError: + return web.Response(status=400, text="Missing loginToken query parameter\n") + except asyncio.InvalidStateError: + return web.Response(status=500, text="Invalid state\n") + return web.Response( + status=200, + text="Login token received, please return to your Maubot client. " + "This tab can be closed.\n", + ) diff --git a/maubot/management/api/client_proxy.py b/maubot/management/api/client_proxy.py index 8c293cd..3fa682b 100644 --- a/maubot/management/api/client_proxy.py +++ b/maubot/management/api/client_proxy.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web, client as http +from aiohttp import client as http, web from ...client import Client from .base import routes @@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024 @routes.view("/proxy/{id}/{path:_matrix/.+}") async def proxy(request: web.Request) -> web.StreamResponse: user_id = request.match_info.get("id", None) - client = Client.get(user_id, None) + client = await Client.get(user_id) if not client: return resp.client_not_found @@ -45,8 +45,9 @@ async def proxy(request: web.Request) -> web.StreamResponse: headers["X-Forwarded-For"] = f"{host}:{port}" data = await request.read() - async with http.request(request.method, f"{client.homeserver}/{path}", headers=headers, - params=query, data=data) as proxy_resp: + async with http.request( + request.method, f"{client.homeserver}/{path}", headers=headers, params=query, data=data + ) as proxy_resp: response = web.StreamResponse(status=proxy_resp.status, headers=proxy_resp.headers) await response.prepare(request) async for chunk in proxy_resp.content.iter_chunked(PROXY_CHUNK_SIZE): diff --git a/maubot/management/api/dev_open.py b/maubot/management/api/dev_open.py index 323c515..2881d46 100644 --- a/maubot/management/api/dev_open.py +++ b/maubot/management/api/dev_open.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,11 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from string import Template -from subprocess import run +import asyncio import re -from ruamel.yaml import YAML from aiohttp import web +from ruamel.yaml import YAML from .base import routes @@ -27,9 +27,7 @@ enabled = False @routes.get("/debug/open") async def check_enabled(_: web.Request) -> web.Response: - return web.json_response({ - "enabled": enabled, - }) + return web.json_response({"enabled": enabled}) try: @@ -40,7 +38,6 @@ try: editor_command = Template(cfg["editor"]) pathmap = [(re.compile(item["find"]), item["replace"]) for item in cfg["pathmap"]] - @routes.post("/debug/open") async def open_file(request: web.Request) -> web.Response: data = await request.json() @@ -51,13 +48,9 @@ try: cmd = editor_command.substitute(path=path, line=data["line"]) except (KeyError, ValueError): return web.Response(status=400) - res = run(cmd, shell=True) - return web.json_response({ - "return": res.returncode, - "stdout": res.stdout, - "stderr": res.stderr - }) - + res = await asyncio.create_subprocess_shell(cmd) + stdout, stderr = await res.communicate() + return web.json_response({"return": res.returncode, "stdout": stdout, "stderr": stderr}) enabled = True except Exception: diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py index 91861af..4043221 100644 --- a/maubot/management/api/instance.py +++ b/maubot/management/api/instance.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,10 +17,9 @@ from json import JSONDecodeError from aiohttp import web -from ...db import DBPlugin +from ...client import Client from ...instance import PluginInstance from ...loader import PluginLoader -from ...client import Client from .base import routes from .responses import resp @@ -32,51 +31,50 @@ async def get_instances(_: web.Request) -> web.Response: @routes.get("/instance/{id}") async def get_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found return resp.found(instance.to_dict()) async def _create_instance(instance_id: str, data: dict) -> web.Response: - plugin_type = data.get("type", None) - primary_user = data.get("primary_user", None) + plugin_type = data.get("type") + primary_user = data.get("primary_user") if not plugin_type: return resp.plugin_type_required elif not primary_user: return resp.primary_user_required - elif not Client.get(primary_user): + elif not await Client.get(primary_user): return resp.primary_user_not_found try: PluginLoader.find(plugin_type) except KeyError: return resp.plugin_type_not_found - db_instance = DBPlugin(id=instance_id, type=plugin_type, enabled=data.get("enabled", True), - primary_user=primary_user, config=data.get("config", "")) - instance = PluginInstance(db_instance) - instance.load() - instance.db_instance.insert() + instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user) + instance.enabled = data.get("enabled", True) + instance.config_str = data.get("config") or "" + await instance.update() + await instance.load() await instance.start() return resp.created(instance.to_dict()) async def _update_instance(instance: PluginInstance, data: dict) -> web.Response: - if not await instance.update_primary_user(data.get("primary_user", None)): + if not await instance.update_primary_user(data.get("primary_user")): return resp.primary_user_not_found - with instance.db_instance.edit_mode(): - instance.update_id(data.get("id", None)) - instance.update_enabled(data.get("enabled", None)) - instance.update_config(data.get("config", None)) - await instance.update_started(data.get("started", None)) - await instance.update_type(data.get("type", None)) - return resp.updated(instance.to_dict()) + await instance.update_id(data.get("id")) + await instance.update_enabled(data.get("enabled")) + await instance.update_config(data.get("config")) + await instance.update_started(data.get("started")) + await instance.update_type(data.get("type")) + return resp.updated(instance.to_dict()) @routes.put("/instance/{id}") async def update_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) try: data = await request.json() except JSONDecodeError: @@ -89,11 +87,11 @@ async def update_instance(request: web.Request) -> web.Response: @routes.delete("/instance/{id}") async def delete_instance(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "").lower() - instance = PluginInstance.get(instance_id) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found if instance.started: await instance.stop() - instance.delete() + await instance.delete() return resp.deleted diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py index bc3baf3..2f8c37a 100644 --- a/maubot/management/api/instance_database.py +++ b/maubot/management/api/instance_database.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,80 +13,67 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Union, TYPE_CHECKING +from __future__ import annotations + from datetime import datetime from aiohttp import web -from sqlalchemy import Table, Column, asc, desc, exc -from sqlalchemy.orm import Query -from sqlalchemy.engine.result import ResultProxy, RowProxy +from asyncpg import PostgresError +import aiosqlite + +from mautrix.util.async_db import Database from ...instance import PluginInstance +from ...lib.optionalalchemy import Engine, IntegrityError, OperationalError, asc, desc from .base import routes from .responses import resp @routes.get("/instance/{id}/database") async def get_database(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - if TYPE_CHECKING: - table: Table - column: Column - return web.json_response({ - table.name: { - "columns": { - column.name: { - "type": str(column.type), - "unique": column.unique or False, - "default": column.default, - "nullable": column.nullable, - "primary": column.primary_key, - "autoincrement": column.autoincrement, - } for column in table.columns - }, - } for table in instance.get_db_tables().values() - }) - - -def check_type(val): - if isinstance(val, datetime): - return val.isoformat() - return val + return web.json_response(await instance.get_db_tables()) @routes.get("/instance/{id}/database/{table}") async def get_table(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: return resp.plugin_has_no_database - tables = instance.get_db_tables() + tables = await instance.get_db_tables() try: table = tables[request.match_info.get("table", "")] except KeyError: return resp.table_not_found try: order = [tuple(order.split(":")) for order in request.query.getall("order")] - order = [(asc if sort.lower() == "asc" else desc)(table.columns[column]) - if sort else table.columns[column] - for column, sort in order] + order = [ + ( + (asc if sort.lower() == "asc" else desc)(table.columns[column]) + if sort + else table.columns[column] + ) + for column, sort in order + ] except KeyError: order = [] - limit = int(request.query.get("limit", 100)) - return execute_query(instance, table.select().order_by(*order).limit(limit)) + limit = int(request.query.get("limit", "100")) + if isinstance(instance.inst_db, Engine): + return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit)) @routes.post("/instance/{id}/database/query") async def query(request: web.Request) -> web.Response: - instance_id = request.match_info.get("id", "") - instance = PluginInstance.get(instance_id, None) + instance_id = request.match_info["id"].lower() + instance = await PluginInstance.get(instance_id) if not instance: return resp.instance_not_found elif not instance.inst_db: @@ -96,28 +83,76 @@ async def query(request: web.Request) -> web.Response: sql_query = data["query"] except KeyError: return resp.query_missing - return execute_query(instance, sql_query, - rows_as_dict=data.get("rows_as_dict", False)) + rows_as_dict = data.get("rows_as_dict", False) + if isinstance(instance.inst_db, Engine): + return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict) + elif isinstance(instance.inst_db, Database): + try: + return await _execute_query_asyncpg(instance, sql_query, rows_as_dict) + except (PostgresError, aiosqlite.Error) as e: + return resp.sql_error(e, sql_query) + else: + return resp.unsupported_plugin_database -def execute_query(instance: PluginInstance, sql_query: Union[str, Query], - rows_as_dict: bool = False) -> web.Response: +def check_type(val): + if isinstance(val, datetime): + return val.isoformat() + return val + + +async def _execute_query_asyncpg( + instance: PluginInstance, sql_query: str, rows_as_dict: bool = False +) -> web.Response: + data = {"ok": True, "query": sql_query} + if sql_query.upper().startswith("SELECT"): + res = await instance.inst_db.fetch(sql_query) + data["rows"] = [ + ( + {key: check_type(value) for key, value in row.items()} + if rows_as_dict + else [check_type(value) for value in row] + ) + for row in res + ] + if len(res) > 0: + # TODO can we find column names when there are no rows? + data["columns"] = list(res[0].keys()) + else: + res = await instance.inst_db.execute(sql_query) + if isinstance(res, str): + data["status_msg"] = res + elif isinstance(res, aiosqlite.Cursor): + data["rowcount"] = res.rowcount + # data["inserted_primary_key"] = res.lastrowid + else: + data["status_msg"] = "unknown status" + return web.json_response(data) + + +def _execute_query_sqlalchemy( + instance: PluginInstance, sql_query: str, rows_as_dict: bool = False +) -> web.Response: + assert isinstance(instance.inst_db, Engine) try: - res: ResultProxy = instance.inst_db.execute(sql_query) - except exc.IntegrityError as e: + res = instance.inst_db.execute(sql_query) + except IntegrityError as e: return resp.sql_integrity_error(e, sql_query) - except exc.OperationalError as e: + except OperationalError as e: return resp.sql_operational_error(e, sql_query) data = { "ok": True, "query": str(sql_query), } if res.returns_rows: - row: RowProxy - data["rows"] = [({key: check_type(value) for key, value in row.items()} - if rows_as_dict - else [check_type(value) for value in row]) - for row in res] + data["rows"] = [ + ( + {key: check_type(value) for key, value in row.items()} + if rows_as_dict + else [check_type(value) for value in row] + ) + for row in res + ] data["columns"] = res.keys() else: data["rowcount"] = res.rowcount diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py index d6ec092..14c80cd 100644 --- a/maubot/management/api/log.py +++ b/maubot/management/api/log.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,31 +13,62 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Deque, List -from datetime import datetime +from __future__ import annotations + from collections import deque -import logging +from datetime import datetime import asyncio +import logging -from aiohttp import web +from aiohttp import web, web_ws + +from mautrix.util import background_task -from .base import routes, get_loop from .auth import is_valid_token +from .base import routes -BUILTIN_ATTRS = {"args", "asctime", "created", "exc_info", "exc_text", "filename", "funcName", - "levelname", "levelno", "lineno", "module", "msecs", "message", "msg", "name", - "pathname", "process", "processName", "relativeCreated", "stack_info", "thread", - "threadName"} -INCLUDE_ATTRS = {"filename", "funcName", "levelname", "levelno", "lineno", "module", "name", - "pathname"} +BUILTIN_ATTRS = { + "args", + "asctime", + "created", + "exc_info", + "exc_text", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "msg", + "name", + "pathname", + "process", + "processName", + "relativeCreated", + "stack_info", + "thread", + "threadName", +} +INCLUDE_ATTRS = { + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "name", + "pathname", +} EXCLUDE_ATTRS = BUILTIN_ATTRS - INCLUDE_ATTRS MAX_LINES = 2048 class LogCollector(logging.Handler): - lines: Deque[dict] + lines: deque[dict] formatter: logging.Formatter - listeners: List[web.WebSocketResponse] + listeners: list[web.WebSocketResponse] loop: asyncio.AbstractEventLoop def __init__(self, level=logging.NOTSET) -> None: @@ -56,9 +87,7 @@ class LogCollector(logging.Handler): # JSON conversion based on Marsel Mavletkulov's json-log-formatter (MIT license) # https://github.com/marselester/json-log-formatter content = { - name: value - for name, value in record.__dict__.items() - if name not in EXCLUDE_ATTRS + name: value for name, value in record.__dict__.items() if name not in EXCLUDE_ATTRS } content["id"] = str(record.relativeCreated) content["msg"] = record.getMessage() @@ -82,18 +111,18 @@ class LogCollector(logging.Handler): handler = LogCollector() -log_root = logging.getLogger("maubot") log = logging.getLogger("maubot.server.websocket") sockets = [] def init(loop: asyncio.AbstractEventLoop) -> None: - log_root.addHandler(handler) + logging.root.addHandler(handler) handler.loop = loop async def stop_all() -> None: - log_root.removeHandler(handler) + log.debug("Closing log listener websockets") + logging.root.removeHandler(handler) for socket in sockets: try: await socket.close(code=1012) @@ -110,14 +139,15 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse: authenticated = False async def close_if_not_authenticated(): - await asyncio.sleep(5, loop=get_loop()) + await asyncio.sleep(5) if not authenticated: await ws.close(code=4000) log.debug(f"Connection from {request.remote} terminated due to no authentication") - asyncio.ensure_future(close_if_not_authenticated()) + background_task.create(close_if_not_authenticated()) try: + msg: web_ws.WSMessage async for msg in ws: if msg.type != web.WSMsgType.TEXT: continue diff --git a/maubot/management/api/login.py b/maubot/management/api/login.py index 21f9342..bfb2f6a 100644 --- a/maubot/management/api/login.py +++ b/maubot/management/api/login.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,9 +17,10 @@ import json from aiohttp import web -from .base import routes, get_config -from .responses import resp from .auth import create_token +from .base import get_config, routes +from .responses import resp + @routes.post("/auth/login") async def login(request: web.Request) -> web.Response: diff --git a/maubot/management/api/middleware.py b/maubot/management/api/middleware.py index ff6b4c1..17141fa 100644 --- a/maubot/management/api/middleware.py +++ b/maubot/management/api/middleware.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,14 +13,15 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Callable, Awaitable +from typing import Awaitable, Callable +import base64 import logging from aiohttp import web -from .responses import resp from .auth import check_token from .base import get_config +from .responses import resp Handler = Callable[[web.Request], Awaitable[web.Response]] log = logging.getLogger("maubot.server") @@ -28,8 +29,13 @@ log = logging.getLogger("maubot.server") @web.middleware async def auth(request: web.Request, handler: Handler) -> web.Response: - subpath = request.path[len(get_config()["server.base_path"]):] - if subpath.startswith("/auth/") or subpath == "/features" or subpath == "/logs": + subpath = request.path[len("/_matrix/maubot/v1") :] + if ( + subpath.startswith("/auth/") + or subpath.startswith("/client/auth_external_sso/complete/") + or subpath == "/features" + or subpath == "/logs" + ): return await handler(request) err = check_token(request) if err is not None: @@ -46,10 +52,18 @@ async def error(request: web.Request, handler: Handler) -> web.Response: return resp.path_not_found elif ex.status_code == 405: return resp.method_not_allowed - return web.json_response({ - "error": f"Unhandled HTTP {ex.status}", - "errcode": f"unhandled_http_{ex.status}", - }, status=ex.status) + return web.json_response( + { + "httpexception": { + "headers": {key: value for key, value in ex.headers.items()}, + "class": type(ex).__name__, + "body": ex.text or base64.b64encode(ex.body), + }, + "error": f"Unhandled HTTP {ex.status}: {ex.text[:128] or 'non-text response'}", + "errcode": f"unhandled_http_{ex.status}", + }, + status=ex.status, + ) except Exception: log.exception("Error in handler") return resp.internal_server_error diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py index 4429e11..94d8d9d 100644 --- a/maubot/management/api/plugin.py +++ b/maubot/management/api/plugin.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -17,9 +17,9 @@ import traceback from aiohttp import web -from ...loader import PluginLoader, MaubotZipImportError -from .responses import resp +from ...loader import MaubotZipImportError, PluginLoader from .base import routes +from .responses import resp @routes.get("/plugins") @@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response: @routes.get("/plugin/{id}") async def get_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found return resp.found(plugin.to_dict()) @@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response: @routes.delete("/plugin/{id}") async def delete_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found elif len(plugin.references) > 0: @@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response: @routes.post("/plugin/{id}/reload") async def reload_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) - plugin = PluginLoader.id_cache.get(plugin_id, None) + plugin_id = request.match_info["id"] + plugin = PluginLoader.id_cache.get(plugin_id) if not plugin: return resp.plugin_not_found diff --git a/maubot/management/api/plugin_upload.py b/maubot/management/api/plugin_upload.py index 7b5b5de..4cd2c47 100644 --- a/maubot/management/api/plugin_upload.py +++ b/maubot/management/api/plugin_upload.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -15,27 +15,39 @@ # along with this program. If not, see . from io import BytesIO from time import time -import traceback +import logging import os.path import re +import traceback from aiohttp import web from packaging.version import Version -from ...loader import PluginLoader, ZippedPluginLoader, MaubotZipImportError +from ...loader import DatabaseType, MaubotZipImportError, PluginLoader, ZippedPluginLoader +from .base import get_config, routes from .responses import resp -from .base import routes, get_config + +try: + import sqlalchemy + + has_alchemy = True +except ImportError: + has_alchemy = False + +log = logging.getLogger("maubot.server.upload") @routes.put("/plugin/{id}") async def put_plugin(request: web.Request) -> web.Response: - plugin_id = request.match_info.get("id", None) + plugin_id = request.match_info["id"] content = await request.read() file = BytesIO(content) try: - pid, version = ZippedPluginLoader.verify_meta(file) + pid, version, db_type = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) + if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: + return resp.sqlalchemy_not_installed if pid != plugin_id: return resp.pid_mismatch plugin = PluginLoader.id_cache.get(plugin_id, None) @@ -52,9 +64,11 @@ async def upload_plugin(request: web.Request) -> web.Response: content = await request.read() file = BytesIO(content) try: - pid, version = ZippedPluginLoader.verify_meta(file) + pid, version, db_type = ZippedPluginLoader.verify_meta(file) except MaubotZipImportError as e: return resp.plugin_import_error(str(e), traceback.format_exc()) + if db_type == DatabaseType.SQLALCHEMY and not has_alchemy: + return resp.sqlalchemy_not_installed plugin = PluginLoader.id_cache.get(pid, None) if not plugin: return await upload_new_plugin(content, pid, version) @@ -78,15 +92,20 @@ async def upload_new_plugin(content: bytes, pid: str, version: Version) -> web.R return resp.created(plugin.to_dict()) -async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, - new_version: Version) -> web.Response: +async def upload_replacement_plugin( + plugin: ZippedPluginLoader, content: bytes, new_version: Version +) -> web.Response: dirname = os.path.dirname(plugin.path) old_filename = os.path.basename(plugin.path) if str(plugin.meta.version) in old_filename: - replacement = (str(new_version) if plugin.meta.version != new_version - else f"{new_version}-ts{int(time())}") - filename = re.sub(f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", - replacement, old_filename) + replacement = ( + str(new_version) + if plugin.meta.version != new_version + else f"{new_version}-ts{int(time() * 1000)}" + ) + filename = re.sub( + f"{re.escape(str(plugin.meta.version))}(-ts[0-9]+)?", replacement, old_filename + ) else: filename = old_filename.rstrip(".mbp") filename = f"{filename}-v{new_version}.mbp" @@ -98,12 +117,29 @@ async def upload_replacement_plugin(plugin: ZippedPluginLoader, content: bytes, try: await plugin.reload(new_path=path) except MaubotZipImportError as e: + log.exception(f"Error loading updated version of {plugin.meta.id}, rolling back") try: await plugin.reload(new_path=old_path) await plugin.start_instances() except MaubotZipImportError: - pass + log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True) + finally: + ZippedPluginLoader.trash(path, reason="failed_update") return resp.plugin_import_error(str(e), traceback.format_exc()) - await plugin.start_instances() + try: + await plugin.start_instances() + except Exception as e: + log.exception(f"Error starting {plugin.meta.id} instances after update, rolling back") + try: + await plugin.stop_instances() + await plugin.reload(new_path=old_path) + await plugin.start_instances() + except Exception: + log.warning(f"Failed to roll back update of {plugin.meta.id}", exc_info=True) + finally: + ZippedPluginLoader.trash(path, reason="failed_update") + return resp.plugin_reload_error(str(e), traceback.format_exc()) + + log.debug(f"Successfully updated {plugin.meta.id}, moving old version to trash") ZippedPluginLoader.trash(old_path, reason="update") return resp.updated(plugin.to_dict()) diff --git a/maubot/management/api/responses.py b/maubot/management/api/responses.py index b30d49d..0f22caa 100644 --- a/maubot/management/api/responses.py +++ b/maubot/management/api/responses.py @@ -1,5 +1,5 @@ # maubot - A plugin-based Matrix bot system. -# Copyright (C) 2019 Tulir Asokan +# Copyright (C) 2022 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -13,271 +13,457 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . +from __future__ import annotations + +from typing import TYPE_CHECKING from http import HTTPStatus from aiohttp import web -from sqlalchemy.exc import OperationalError, IntegrityError +from asyncpg import PostgresError +import aiosqlite + +if TYPE_CHECKING: + from sqlalchemy.exc import IntegrityError, OperationalError class _Response: @property def body_not_json(self) -> web.Response: - return web.json_response({ - "error": "Request body is not JSON", - "errcode": "body_not_json", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Request body is not JSON", + "errcode": "body_not_json", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def plugin_type_required(self) -> web.Response: - return web.json_response({ - "error": "Plugin type is required when creating plugin instances", - "errcode": "plugin_type_required", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Plugin type is required when creating plugin instances", + "errcode": "plugin_type_required", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def primary_user_required(self) -> web.Response: - return web.json_response({ - "error": "Primary user is required when creating plugin instances", - "errcode": "primary_user_required", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Primary user is required when creating plugin instances", + "errcode": "primary_user_required", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def bad_client_access_token(self) -> web.Response: - return web.json_response({ - "error": "Invalid access token", - "errcode": "bad_client_access_token", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Invalid access token", + "errcode": "bad_client_access_token", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def bad_client_access_details(self) -> web.Response: - return web.json_response({ - "error": "Invalid homeserver or access token", - "errcode": "bad_client_access_details" - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Invalid homeserver or access token", + "errcode": "bad_client_access_details", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def bad_client_connection_details(self) -> web.Response: - return web.json_response({ - "error": "Could not connect to homeserver", - "errcode": "bad_client_connection_details" - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Could not connect to homeserver", + "errcode": "bad_client_connection_details", + }, + status=HTTPStatus.BAD_REQUEST, + ) def mxid_mismatch(self, found: str) -> web.Response: - return web.json_response({ - "error": "The Matrix user ID of the client and the user ID of the access token don't " - f"match. Access token is for user {found}", - "errcode": "mxid_mismatch", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": ( + "The Matrix user ID of the client and the user ID of the access token don't " + f"match. Access token is for user {found}" + ), + "errcode": "mxid_mismatch", + }, + status=HTTPStatus.BAD_REQUEST, + ) + + def device_id_mismatch(self, found: str) -> web.Response: + return web.json_response( + { + "error": ( + "The Matrix device ID of the client and the device ID of the access token " + f"don't match. Access token is for device {found}" + ), + "errcode": "mxid_mismatch", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def pid_mismatch(self) -> web.Response: - return web.json_response({ - "error": "The ID in the path does not match the ID of the uploaded plugin", - "errcode": "pid_mismatch", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "The ID in the path does not match the ID of the uploaded plugin", + "errcode": "pid_mismatch", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def username_or_password_missing(self) -> web.Response: - return web.json_response({ - "error": "Username or password missing", - "errcode": "username_or_password_missing", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Username or password missing", + "errcode": "username_or_password_missing", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def query_missing(self) -> web.Response: - return web.json_response({ - "error": "Query missing", - "errcode": "query_missing", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Query missing", + "errcode": "query_missing", + }, + status=HTTPStatus.BAD_REQUEST, + ) + + @staticmethod + def sql_error(error: PostgresError | aiosqlite.Error, query: str) -> web.Response: + return web.json_response( + { + "ok": False, + "query": query, + "error": str(error), + "errcode": "sql_error", + }, + status=HTTPStatus.BAD_REQUEST, + ) @staticmethod def sql_operational_error(error: OperationalError, query: str) -> web.Response: - return web.json_response({ - "ok": False, - "query": query, - "error": str(error.orig), - "full_error": str(error), - "errcode": "sql_operational_error", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "ok": False, + "query": query, + "error": str(error.orig), + "full_error": str(error), + "errcode": "sql_operational_error", + }, + status=HTTPStatus.BAD_REQUEST, + ) @staticmethod def sql_integrity_error(error: IntegrityError, query: str) -> web.Response: - return web.json_response({ - "ok": False, - "query": query, - "error": str(error.orig), - "full_error": str(error), - "errcode": "sql_integrity_error", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "ok": False, + "query": query, + "error": str(error.orig), + "full_error": str(error), + "errcode": "sql_integrity_error", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def bad_auth(self) -> web.Response: - return web.json_response({ - "error": "Invalid username or password", - "errcode": "invalid_auth", - }, status=HTTPStatus.UNAUTHORIZED) + return web.json_response( + { + "error": "Invalid username or password", + "errcode": "invalid_auth", + }, + status=HTTPStatus.UNAUTHORIZED, + ) @property def no_token(self) -> web.Response: - return web.json_response({ - "error": "Authorization token missing", - "errcode": "auth_token_missing", - }, status=HTTPStatus.UNAUTHORIZED) + return web.json_response( + { + "error": "Authorization token missing", + "errcode": "auth_token_missing", + }, + status=HTTPStatus.UNAUTHORIZED, + ) @property def invalid_token(self) -> web.Response: - return web.json_response({ - "error": "Invalid authorization token", - "errcode": "auth_token_invalid", - }, status=HTTPStatus.UNAUTHORIZED) + return web.json_response( + { + "error": "Invalid authorization token", + "errcode": "auth_token_invalid", + }, + status=HTTPStatus.UNAUTHORIZED, + ) @property def plugin_not_found(self) -> web.Response: - return web.json_response({ - "error": "Plugin not found", - "errcode": "plugin_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Plugin not found", + "errcode": "plugin_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def client_not_found(self) -> web.Response: - return web.json_response({ - "error": "Client not found", - "errcode": "client_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Client not found", + "errcode": "client_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def primary_user_not_found(self) -> web.Response: - return web.json_response({ - "error": "Client for given primary user not found", - "errcode": "primary_user_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Client for given primary user not found", + "errcode": "primary_user_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def instance_not_found(self) -> web.Response: - return web.json_response({ - "error": "Plugin instance not found", - "errcode": "instance_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Plugin instance not found", + "errcode": "instance_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def plugin_type_not_found(self) -> web.Response: - return web.json_response({ - "error": "Given plugin type not found", - "errcode": "plugin_type_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Given plugin type not found", + "errcode": "plugin_type_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def path_not_found(self) -> web.Response: - return web.json_response({ - "error": "Resource not found", - "errcode": "resource_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Resource not found", + "errcode": "resource_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) @property def server_not_found(self) -> web.Response: - return web.json_response({ - "error": "Registration target server not found", - "errcode": "server_not_found", - }, status=HTTPStatus.NOT_FOUND) + return web.json_response( + { + "error": "Registration target server not found", + "errcode": "server_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) + + @property + def registration_secret_not_found(self) -> web.Response: + return web.json_response( + { + "error": "Config does not have a registration secret for that server", + "errcode": "registration_secret_not_found", + }, + status=HTTPStatus.NOT_FOUND, + ) + + @property + def registration_no_sso(self) -> web.Response: + return web.json_response( + { + "error": "The register operation is only for registering with a password", + "errcode": "registration_no_sso", + }, + status=HTTPStatus.BAD_REQUEST, + ) + + @property + def sso_not_supported(self) -> web.Response: + return web.json_response( + { + "error": "That server does not seem to support single sign-on", + "errcode": "sso_not_supported", + }, + status=HTTPStatus.FORBIDDEN, + ) @property def plugin_has_no_database(self) -> web.Response: - return web.json_response({ - "error": "Given plugin does not have a database", - "errcode": "plugin_has_no_database", - }) + return web.json_response( + { + "error": "Given plugin does not have a database", + "errcode": "plugin_has_no_database", + } + ) + + @property + def unsupported_plugin_database(self) -> web.Response: + return web.json_response( + { + "error": "The database type is not supported by this API", + "errcode": "unsupported_plugin_database", + } + ) + + @property + def sqlalchemy_not_installed(self) -> web.Response: + return web.json_response( + { + "error": "This plugin requires a legacy database, but SQLAlchemy is not installed", + "errcode": "unsupported_plugin_database", + }, + status=HTTPStatus.NOT_IMPLEMENTED, + ) @property def table_not_found(self) -> web.Response: - return web.json_response({ - "error": "Given table not found in plugin database", - "errcode": "table_not_found", - }) + return web.json_response( + { + "error": "Given table not found in plugin database", + "errcode": "table_not_found", + } + ) @property def method_not_allowed(self) -> web.Response: - return web.json_response({ - "error": "Method not allowed", - "errcode": "method_not_allowed", - }, status=HTTPStatus.METHOD_NOT_ALLOWED) + return web.json_response( + { + "error": "Method not allowed", + "errcode": "method_not_allowed", + }, + status=HTTPStatus.METHOD_NOT_ALLOWED, + ) @property def user_exists(self) -> web.Response: - return web.json_response({ - "error": "There is already a client with the user ID of that token", - "errcode": "user_exists", - }, status=HTTPStatus.CONFLICT) + return web.json_response( + { + "error": "There is already a client with the user ID of that token", + "errcode": "user_exists", + }, + status=HTTPStatus.CONFLICT, + ) @property def plugin_exists(self) -> web.Response: - return web.json_response({ - "error": "A plugin with the same ID as the uploaded plugin already exists", - "errcode": "plugin_exists" - }, status=HTTPStatus.CONFLICT) + return web.json_response( + { + "error": "A plugin with the same ID as the uploaded plugin already exists", + "errcode": "plugin_exists", + }, + status=HTTPStatus.CONFLICT, + ) @property def plugin_in_use(self) -> web.Response: - return web.json_response({ - "error": "Plugin instances of this type still exist", - "errcode": "plugin_in_use", - }, status=HTTPStatus.PRECONDITION_FAILED) + return web.json_response( + { + "error": "Plugin instances of this type still exist", + "errcode": "plugin_in_use", + }, + status=HTTPStatus.PRECONDITION_FAILED, + ) @property def client_in_use(self) -> web.Response: - return web.json_response({ - "error": "Plugin instances with this client as their primary user still exist", - "errcode": "client_in_use", - }, status=HTTPStatus.PRECONDITION_FAILED) + return web.json_response( + { + "error": "Plugin instances with this client as their primary user still exist", + "errcode": "client_in_use", + }, + status=HTTPStatus.PRECONDITION_FAILED, + ) @staticmethod def plugin_import_error(error: str, stacktrace: str) -> web.Response: - return web.json_response({ - "error": error, - "stacktrace": stacktrace, - "errcode": "plugin_invalid", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": error, + "stacktrace": stacktrace, + "errcode": "plugin_invalid", + }, + status=HTTPStatus.BAD_REQUEST, + ) @staticmethod def plugin_reload_error(error: str, stacktrace: str) -> web.Response: - return web.json_response({ - "error": error, - "stacktrace": stacktrace, - "errcode": "plugin_reload_fail", - }, status=HTTPStatus.INTERNAL_SERVER_ERROR) + return web.json_response( + { + "error": error, + "stacktrace": stacktrace, + "errcode": "plugin_reload_fail", + }, + status=HTTPStatus.INTERNAL_SERVER_ERROR, + ) @property def internal_server_error(self) -> web.Response: - return web.json_response({ - "error": "Internal server error", - "errcode": "internal_server_error", - }, status=HTTPStatus.INTERNAL_SERVER_ERROR) + return web.json_response( + { + "error": "Internal server error", + "errcode": "internal_server_error", + }, + status=HTTPStatus.INTERNAL_SERVER_ERROR, + ) @property def invalid_server(self) -> web.Response: - return web.json_response({ - "error": "Invalid registration server object in maubot configuration", - "errcode": "invalid_server", - }, status=HTTPStatus.INTERNAL_SERVER_ERROR) + return web.json_response( + { + "error": "Invalid registration server object in maubot configuration", + "errcode": "invalid_server", + }, + status=HTTPStatus.INTERNAL_SERVER_ERROR, + ) @property def unsupported_plugin_loader(self) -> web.Response: - return web.json_response({ - "error": "Existing plugin with same ID uses unsupported plugin loader", - "errcode": "unsupported_plugin_loader", - }, status=HTTPStatus.BAD_REQUEST) + return web.json_response( + { + "error": "Existing plugin with same ID uses unsupported plugin loader", + "errcode": "unsupported_plugin_loader", + }, + status=HTTPStatus.BAD_REQUEST, + ) @property def not_implemented(self) -> web.Response: - return web.json_response({ - "error": "Not implemented", - "errcode": "not_implemented", - }, status=HTTPStatus.NOT_IMPLEMENTED) + return web.json_response( + { + "error": "Not implemented", + "errcode": "not_implemented", + }, + status=HTTPStatus.NOT_IMPLEMENTED, + ) @property def ok(self) -> web.Response: - return web.json_response({ - "success": True, - }, status=HTTPStatus.OK) + return web.json_response( + {"success": True}, + status=HTTPStatus.OK, + ) @property def deleted(self) -> web.Response: @@ -287,19 +473,15 @@ class _Response: def found(data: dict) -> web.Response: return web.json_response(data, status=HTTPStatus.OK) - def updated(self, data: dict) -> web.Response: - return self.found(data) + @staticmethod + def updated(data: dict, is_login: bool = False) -> web.Response: + return web.json_response(data, status=HTTPStatus.ACCEPTED if is_login else HTTPStatus.OK) def logged_in(self, token: str) -> web.Response: - return self.found({ - "token": token, - }) + return self.found({"token": token}) def pong(self, user: str, features: dict) -> web.Response: - return self.found({ - "username": user, - "features": features, - }) + return self.found({"username": user, "features": features}) @staticmethod def created(data: dict) -> web.Response: diff --git a/maubot/management/api/spec.yaml b/maubot/management/api/spec.yaml index c6f5181..8529599 100644 --- a/maubot/management/api/spec.yaml +++ b/maubot/management/api/spec.yaml @@ -366,7 +366,7 @@ paths: schema: $ref: '#/components/schemas/MatrixClient' responses: - 200: + 202: description: Client updated content: application/json: @@ -454,6 +454,12 @@ paths: required: true schema: type: string + - name: update_client + in: query + description: Should maubot store the access details in a Client instead of returning them? + required: false + schema: + type: boolean post: operationId: client_auth_register summary: | @@ -475,18 +481,29 @@ paths: properties: access_token: type: string - example: token_here + example: syt_123_456_789 user_id: type: string example: '@putkiteippi:maunium.net' - home_server: - type: string - example: maunium.net device_id: type: string - example: device_id_here + example: maubot_F00BAR12 + 201: + description: Client created (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' 401: $ref: '#/components/responses/Unauthorized' + 409: + description: | + There is already a client with the user ID of that token. + This should usually not happen, because the user ID was just created. + content: + application/json: + schema: + $ref: '#/components/schemas/Error' 500: $ref: '#/components/responses/MatrixServerError' '/client/auth/{server}/login': @@ -497,6 +514,12 @@ paths: required: true schema: type: string + - name: update_client + in: query + description: Should maubot store the access details in a Client instead of returning them? + required: false + schema: + type: boolean post: operationId: client_auth_login summary: Log in to the given Matrix server via the maubot server @@ -519,10 +542,22 @@ paths: example: '@putkiteippi:maunium.net' access_token: type: string - example: token_here + example: syt_123_456_789 device_id: type: string - example: device_id_here + example: maubot_F00BAR12 + 201: + description: Client created (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' + 202: + description: Client updated (when update_client is true) + content: + application/json: + schema: + $ref: '#/components/schemas/MatrixClient' 401: $ref: '#/components/responses/Unauthorized' 500: @@ -641,6 +676,12 @@ components: access_token: type: string description: The Matrix access token for this client. + device_id: + type: string + description: The Matrix device ID corresponding to the access token. + fingerprint: + type: string + description: The encryption device fingerprint for verification. enabled: type: boolean example: true diff --git a/maubot/management/frontend/package.json b/maubot/management/frontend/package.json index e2d75ec..294c1f7 100644 --- a/maubot/management/frontend/package.json +++ b/maubot/management/frontend/package.json @@ -13,15 +13,15 @@ }, "homepage": ".", "dependencies": { - "node-sass": "^4.12.0", - "react": "^16.8.6", - "react-ace": "^9.0.0", - "react-contextmenu": "^2.11.0", - "react-dom": "^16.8.6", - "react-json-tree": "^0.11.2", - "react-router-dom": "^5.0.1", - "react-scripts": "3.4.1", - "react-select": "^3.0.4" + "react": "^17.0.2", + "react-ace": "^9.4.1", + "react-contextmenu": "^2.14.0", + "react-dom": "^17.0.2", + "react-json-tree": "^0.16.1", + "react-router-dom": "^5.3.0", + "react-scripts": "5.0.0", + "react-select": "^5.2.1", + "sass": "^1.34.1" }, "scripts": { "start": "react-scripts start", @@ -30,16 +30,11 @@ "eject": "react-scripts eject" }, "browserslist": [ - "last 5 firefox versions", - "last 3 and_ff versions", - "last 5 chrome versions", - "last 3 and_chr versions", - "last 2 safari versions", - "last 2 ios_saf versions" - ], - "devDependencies": { - "sass-lint": "^1.13.1", - "sass-lint-auto-fix": "^0.21.0", - "@babel/helper-call-delegate": "^7.10.4" - } + "last 2 firefox versions", + "last 2 and_ff versions", + "last 2 chrome versions", + "last 2 and_chr versions", + "last 1 safari versions", + "last 1 ios_saf versions" + ] } diff --git a/maubot/management/frontend/public/index.html b/maubot/management/frontend/public/index.html index 43255d8..d3679bf 100644 --- a/maubot/management/frontend/public/index.html +++ b/maubot/management/frontend/public/index.html @@ -1,6 +1,6 @@