diff --git a/Dockerfile b/Dockerfile
index bd27ebe..4fd6f2b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -15,7 +15,6 @@ RUN apk add --no-cache \
py3-attrs \
py3-bcrypt \
py3-cffi \
- py3-psycopg2 \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
@@ -49,7 +48,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
- && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps
diff --git a/Dockerfile.ci b/Dockerfile.ci
index 7719d33..47655a0 100644
--- a/Dockerfile.ci
+++ b/Dockerfile.ci
@@ -10,7 +10,6 @@ RUN apk add --no-cache \
py3-attrs \
py3-bcrypt \
py3-cffi \
- py3-psycopg2 \
py3-ruamel.yaml \
py3-jinja2 \
py3-click \
@@ -43,7 +42,6 @@ COPY requirements.txt /opt/maubot/requirements.txt
COPY optional-requirements.txt /opt/maubot/optional-requirements.txt
WORKDIR /opt/maubot
RUN apk add --virtual .build-deps python3-dev build-base git \
- && sed -Ei 's/psycopg2-binary.+//' optional-requirements.txt \
&& pip3 install -r requirements.txt -r optional-requirements.txt \
dateparser langdetect python-gitlab pyquery cchardet semver tzlocal cssselect \
&& apk del .build-deps
diff --git a/alembic.ini b/alembic.ini
deleted file mode 100644
index 0d78e89..0000000
--- a/alembic.ini
+++ /dev/null
@@ -1,83 +0,0 @@
-# A generic, single database configuration.
-
-[alembic]
-# path to migration scripts
-script_location = alembic
-
-# template used to generate migration files
-# file_template = %%(rev)s_%%(slug)s
-
-# timezone to use when rendering the date
-# within the migration file as well as the filename.
-# string value is passed to dateutil.tz.gettz()
-# leave blank for localtime
-# timezone =
-
-# max length of characters to apply to the
-# "slug" field
-# truncate_slug_length = 40
-
-# set to 'true' to run the environment during
-# the 'revision' command, regardless of autogenerate
-# revision_environment = false
-
-# set to 'true' to allow .pyc and .pyo files without
-# a source .py file to be detected as revisions in the
-# versions/ directory
-# sourceless = false
-
-# version location specification; this defaults
-# to alembic/versions. When using multiple version
-# directories, initial revisions must be specified with --version-path
-# version_locations = %(here)s/bar %(here)s/bat alembic/versions
-
-# the output encoding used when revision files
-# are written from script.py.mako
-# output_encoding = utf-8
-
-
-[post_write_hooks]
-# post_write_hooks defines scripts or Python functions that are run
-# on newly generated revision scripts. See the documentation for further
-# detail and examples
-
-# format using "black" - use the console_scripts runner, against the "black" entrypoint
-# hooks=black
-# black.type=console_scripts
-# black.entrypoint=black
-# black.options=-l 79
-
-# Logging configuration
-[loggers]
-keys = root,sqlalchemy,alembic
-
-[handlers]
-keys = console
-
-[formatters]
-keys = generic
-
-[logger_root]
-level = WARN
-handlers = console
-qualname =
-
-[logger_sqlalchemy]
-level = WARN
-handlers =
-qualname = sqlalchemy.engine
-
-[logger_alembic]
-level = INFO
-handlers =
-qualname = alembic
-
-[handler_console]
-class = StreamHandler
-args = (sys.stderr,)
-level = NOTSET
-formatter = generic
-
-[formatter_generic]
-format = %(levelname)-5.5s [%(name)s] %(message)s
-datefmt = %H:%M:%S
diff --git a/alembic/README b/alembic/README
deleted file mode 100644
index 98e4f9c..0000000
--- a/alembic/README
+++ /dev/null
@@ -1 +0,0 @@
-Generic single-database configuration.
\ No newline at end of file
diff --git a/alembic/env.py b/alembic/env.py
deleted file mode 100644
index 9946810..0000000
--- a/alembic/env.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from logging.config import fileConfig
-
-from sqlalchemy import engine_from_config, pool
-
-from alembic import context
-
-import sys
-from os.path import abspath, dirname
-
-sys.path.insert(0, dirname(dirname(abspath(__file__))))
-
-from mautrix.util.db import Base
-from maubot.config import Config
-from maubot import db
-
-# this is the Alembic Config object, which provides
-# access to the values within the .ini file in use.
-config = context.config
-
-maubot_config_path = context.get_x_argument(as_dictionary=True).get("config", "config.yaml")
-maubot_config = Config(maubot_config_path, None)
-maubot_config.load()
-config.set_main_option("sqlalchemy.url", maubot_config["database"].replace("%", "%%"))
-
-# Interpret the config file for Python logging.
-# This line sets up loggers basically.
-fileConfig(config.config_file_name)
-
-# add your model's MetaData object here
-# for 'autogenerate' support
-# from myapp import mymodel
-# target_metadata = mymodel.Base.metadata
-target_metadata = Base.metadata
-
-# other values from the config, defined by the needs of env.py,
-# can be acquired:
-# my_important_option = config.get_main_option("my_important_option")
-# ... etc.
-
-
-def run_migrations_offline():
- """Run migrations in 'offline' mode.
-
- This configures the context with just a URL
- and not an Engine, though an Engine is acceptable
- here as well. By skipping the Engine creation
- we don't even need a DBAPI to be available.
-
- Calls to context.execute() here emit the given string to the
- script output.
-
- """
- url = config.get_main_option("sqlalchemy.url")
- context.configure(
- url=url,
- target_metadata=target_metadata,
- literal_binds=True,
- dialect_opts={"paramstyle": "named"},
- render_as_batch=True,
- )
-
- with context.begin_transaction():
- context.run_migrations()
-
-
-def run_migrations_online():
- """Run migrations in 'online' mode.
-
- In this scenario we need to create an Engine
- and associate a connection with the context.
-
- """
- connectable = engine_from_config(
- config.get_section(config.config_ini_section),
- prefix="sqlalchemy.",
- poolclass=pool.NullPool,
- )
-
- with connectable.connect() as connection:
- context.configure(
- connection=connection, target_metadata=target_metadata,
- render_as_batch=True,
- )
-
- with context.begin_transaction():
- context.run_migrations()
-
-
-if context.is_offline_mode():
- run_migrations_offline()
-else:
- run_migrations_online()
diff --git a/alembic/script.py.mako b/alembic/script.py.mako
deleted file mode 100644
index 2c01563..0000000
--- a/alembic/script.py.mako
+++ /dev/null
@@ -1,24 +0,0 @@
-"""${message}
-
-Revision ID: ${up_revision}
-Revises: ${down_revision | comma,n}
-Create Date: ${create_date}
-
-"""
-from alembic import op
-import sqlalchemy as sa
-${imports if imports else ""}
-
-# revision identifiers, used by Alembic.
-revision = ${repr(up_revision)}
-down_revision = ${repr(down_revision)}
-branch_labels = ${repr(branch_labels)}
-depends_on = ${repr(depends_on)}
-
-
-def upgrade():
- ${upgrades if upgrades else "pass"}
-
-
-def downgrade():
- ${downgrades if downgrades else "pass"}
diff --git a/alembic/versions/4b93300852aa_add_device_id_to_clients.py b/alembic/versions/4b93300852aa_add_device_id_to_clients.py
deleted file mode 100644
index efc71cd..0000000
--- a/alembic/versions/4b93300852aa_add_device_id_to_clients.py
+++ /dev/null
@@ -1,32 +0,0 @@
-"""Add device_id to clients
-
-Revision ID: 4b93300852aa
-Revises: fccd1f95544d
-Create Date: 2020-07-11 15:49:38.831459
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = '4b93300852aa'
-down_revision = 'fccd1f95544d'
-branch_labels = None
-depends_on = None
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- with op.batch_alter_table('client', schema=None) as batch_op:
- batch_op.add_column(sa.Column('device_id', sa.String(length=255), nullable=True))
-
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- with op.batch_alter_table('client', schema=None) as batch_op:
- batch_op.drop_column('device_id')
-
- # ### end Alembic commands ###
diff --git a/alembic/versions/90aa88820eab_add_matrix_state_store.py b/alembic/versions/90aa88820eab_add_matrix_state_store.py
deleted file mode 100644
index 37a68eb..0000000
--- a/alembic/versions/90aa88820eab_add_matrix_state_store.py
+++ /dev/null
@@ -1,47 +0,0 @@
-"""Add Matrix state store
-
-Revision ID: 90aa88820eab
-Revises: 4b93300852aa
-Create Date: 2020-07-12 01:50:06.215623
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-from mautrix.client.state_store.sqlalchemy import SerializableType
-from mautrix.types import PowerLevelStateEventContent, RoomEncryptionStateEventContent
-
-
-# revision identifiers, used by Alembic.
-revision = '90aa88820eab'
-down_revision = '4b93300852aa'
-branch_labels = None
-depends_on = None
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.create_table('mx_room_state',
- sa.Column('room_id', sa.String(length=255), nullable=False),
- sa.Column('is_encrypted', sa.Boolean(), nullable=True),
- sa.Column('has_full_member_list', sa.Boolean(), nullable=True),
- sa.Column('encryption', SerializableType(RoomEncryptionStateEventContent), nullable=True),
- sa.Column('power_levels', SerializableType(PowerLevelStateEventContent), nullable=True),
- sa.PrimaryKeyConstraint('room_id')
- )
- op.create_table('mx_user_profile',
- sa.Column('room_id', sa.String(length=255), nullable=False),
- sa.Column('user_id', sa.String(length=255), nullable=False),
- sa.Column('membership', sa.Enum('JOIN', 'LEAVE', 'INVITE', 'BAN', 'KNOCK', name='membership'), nullable=False),
- sa.Column('displayname', sa.String(), nullable=True),
- sa.Column('avatar_url', sa.String(length=255), nullable=True),
- sa.PrimaryKeyConstraint('room_id', 'user_id')
- )
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('mx_user_profile')
- op.drop_table('mx_room_state')
- # ### end Alembic commands ###
diff --git a/alembic/versions/d295f8dcfa64_initial_revision.py b/alembic/versions/d295f8dcfa64_initial_revision.py
deleted file mode 100644
index ffa502f..0000000
--- a/alembic/versions/d295f8dcfa64_initial_revision.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Initial revision
-
-Revision ID: d295f8dcfa64
-Revises:
-Create Date: 2019-09-27 00:21:02.527915
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = 'd295f8dcfa64'
-down_revision = None
-branch_labels = None
-depends_on = None
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.create_table('client',
- sa.Column('id', sa.String(length=255), nullable=False),
- sa.Column('homeserver', sa.String(length=255), nullable=False),
- sa.Column('access_token', sa.Text(), nullable=False),
- sa.Column('enabled', sa.Boolean(), nullable=False),
- sa.Column('next_batch', sa.String(length=255), nullable=False),
- sa.Column('filter_id', sa.String(length=255), nullable=False),
- sa.Column('sync', sa.Boolean(), nullable=False),
- sa.Column('autojoin', sa.Boolean(), nullable=False),
- sa.Column('displayname', sa.String(length=255), nullable=False),
- sa.Column('avatar_url', sa.String(length=255), nullable=False),
- sa.PrimaryKeyConstraint('id')
- )
- op.create_table('plugin',
- sa.Column('id', sa.String(length=255), nullable=False),
- sa.Column('type', sa.String(length=255), nullable=False),
- sa.Column('enabled', sa.Boolean(), nullable=False),
- sa.Column('primary_user', sa.String(length=255), nullable=False),
- sa.Column('config', sa.Text(), nullable=False),
- sa.ForeignKeyConstraint(['primary_user'], ['client.id'], onupdate='CASCADE', ondelete='RESTRICT'),
- sa.PrimaryKeyConstraint('id')
- )
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- op.drop_table('plugin')
- op.drop_table('client')
- # ### end Alembic commands ###
diff --git a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py b/alembic/versions/fccd1f95544d_add_online_field_to_clients.py
deleted file mode 100644
index 1f7eabe..0000000
--- a/alembic/versions/fccd1f95544d_add_online_field_to_clients.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Add online field to clients
-
-Revision ID: fccd1f95544d
-Revises: d295f8dcfa64
-Create Date: 2020-03-06 15:07:50.136644
-
-"""
-from alembic import op
-import sqlalchemy as sa
-
-
-# revision identifiers, used by Alembic.
-revision = 'fccd1f95544d'
-down_revision = 'd295f8dcfa64'
-branch_labels = None
-depends_on = None
-
-
-def upgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- with op.batch_alter_table("client") as batch_op:
- batch_op.add_column(sa.Column('online', sa.Boolean(), nullable=False, server_default=sa.sql.expression.true()))
- # ### end Alembic commands ###
-
-
-def downgrade():
- # ### commands auto generated by Alembic - please adjust! ###
- with op.batch_alter_table("client") as batch_op:
- batch_op.drop_column('online')
- # ### end Alembic commands ###
diff --git a/docker/run.sh b/docker/run.sh
index a9a40e1..9ca3a3f 100755
--- a/docker/run.sh
+++ b/docker/run.sh
@@ -1,7 +1,7 @@
#!/bin/sh
function fixperms {
- chown -R $UID:$GID /var/log /data /opt/maubot
+ chown -R $UID:$GID /var/log /data
}
function fixdefault {
diff --git a/maubot/__main__.py b/maubot/__main__.py
index a29f347..f7865ea 100644
--- a/maubot/__main__.py
+++ b/maubot/__main__.py
@@ -13,24 +13,37 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-import asyncio
+from __future__ import annotations
+import asyncio
+import sys
+
+from mautrix.util.async_db import Database, DatabaseException
from mautrix.util.program import Program
from .__meta__ import __version__
-from .client import Client, init as init_client_class
+from .client import Client
from .config import Config
-from .db import init as init_db
-from .instance import init as init_plugin_instance_class
+from .db import init as init_db, upgrade_table
+from .instance import PluginInstance
from .lib.future_awaitable import FutureAwaitable
+from .lib.state_store import PgStateStore
from .loader.zip import init as init_zip_loader
from .management.api import init as init_mgmt_api
from .server import MaubotServer
+try:
+ from mautrix.crypto.store import PgCryptoStore
+except ImportError:
+ PgCryptoStore = None
+
class Maubot(Program):
config: Config
server: MaubotServer
+ db: Database
+ crypto_db: Database | None
+ state_store: PgStateStore
config_class = Config
module = "maubot"
@@ -45,6 +58,19 @@ class Maubot(Program):
init(self.loop)
self.add_shutdown_actions(FutureAwaitable(stop_all))
+ def prepare_arg_parser(self) -> None:
+ super().prepare_arg_parser()
+ self.parser.add_argument(
+ "--ignore-unsupported-database",
+ action="store_true",
+ help="Run even if the database schema is too new",
+ )
+ self.parser.add_argument(
+ "--ignore-foreign-tables",
+ action="store_true",
+ help="Run even if the database contains tables from other programs (like Synapse)",
+ )
+
def prepare(self) -> None:
super().prepare()
@@ -52,21 +78,59 @@ class Maubot(Program):
self.prepare_log_websocket()
init_zip_loader(self.config)
- init_db(self.config)
- clients = init_client_class(self.config, self.loop)
- self.add_startup_actions(*(client.start() for client in clients))
+ self.db = Database.create(
+ self.config["database"],
+ upgrade_table=upgrade_table,
+ db_args=self.config["database_opts"],
+ owner_name=self.name,
+ ignore_foreign_tables=self.args.ignore_foreign_tables,
+ )
+ init_db(self.db)
+ if self.config["crypto_database"] == "default":
+ self.crypto_db = self.db
+ else:
+ self.crypto_db = Database.create(
+ self.config["crypto_database"],
+ upgrade_table=PgCryptoStore.upgrade_table,
+ ignore_foreign_tables=self.args.ignore_foreign_tables,
+ )
+ Client.init_cls(self)
+ PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)
self.server = MaubotServer(management_api, self.config, self.loop)
+ self.state_store = PgStateStore(self.db)
- plugins = init_plugin_instance_class(self.config, self.server, self.loop)
- for plugin in plugins:
- plugin.load()
+ async def start_db(self) -> None:
+ self.log.debug("Starting database...")
+ ignore_unsupported = self.args.ignore_unsupported_database
+ self.db.upgrade_table.allow_unsupported = ignore_unsupported
+ self.state_store.upgrade_table.allow_unsupported = ignore_unsupported
+ PgCryptoStore.upgrade_table.allow_unsupported = ignore_unsupported
+ try:
+ await self.db.start()
+ await self.state_store.upgrade_table.upgrade(self.db)
+ if self.crypto_db and self.crypto_db is not self.db:
+ await self.crypto_db.start()
+ else:
+ await PgCryptoStore.upgrade_table.upgrade(self.db)
+ except DatabaseException as e:
+ self.log.critical("Failed to initialize database", exc_info=e)
+ if e.explanation:
+ self.log.info(e.explanation)
+ sys.exit(25)
+
+ async def system_exit(self) -> None:
+ if hasattr(self, "db"):
+ self.log.trace("Stopping database due to SystemExit")
+ await self.db.stop()
async def start(self) -> None:
- if Client.crypto_db:
- self.log.debug("Starting client crypto database")
- await Client.crypto_db.start()
+ await self.start_db()
+ await asyncio.gather(*[plugin.load() async for plugin in PluginInstance.all()])
+ await asyncio.gather(*[client.start() async for client in Client.all()])
await super().start()
+ async for plugin in PluginInstance.all():
+ await plugin.load()
await self.server.start()
async def stop(self) -> None:
@@ -77,6 +141,7 @@ class Maubot(Program):
await asyncio.wait_for(self.server.stop(), 5)
except asyncio.TimeoutError:
self.log.warning("Stopping server timed out")
+ await self.db.stop()
Maubot().run()
diff --git a/maubot/__meta__.py b/maubot/__meta__.py
index 3ced358..690354d 100644
--- a/maubot/__meta__.py
+++ b/maubot/__meta__.py
@@ -1 +1 @@
-__version__ = "0.2.1"
+__version__ = "0.3.0+dev"
diff --git a/maubot/cli/commands/logs.py b/maubot/cli/commands/logs.py
index 98879ee..9a9c644 100644
--- a/maubot/cli/commands/logs.py
+++ b/maubot/cli/commands/logs.py
@@ -38,7 +38,7 @@ def logs(server: str, tail: int) -> None:
global history_count
history_count = tail
loop = asyncio.get_event_loop()
- future = asyncio.ensure_future(view_logs(server, token), loop=loop)
+ future = asyncio.create_task(view_logs(server, token), loop=loop)
try:
loop.run_until_complete(future)
except KeyboardInterrupt:
diff --git a/maubot/client.py b/maubot/client.py
index fa6c851..315f217 100644
--- a/maubot/client.py
+++ b/maubot/client.py
@@ -15,14 +15,14 @@
# along with this program. If not, see .
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, Callable, cast
+from collections import defaultdict
import asyncio
import logging
from aiohttp import ClientSession
from mautrix.client import InternalEventType
-from mautrix.client.state_store.sqlalchemy import SQLStateStore as BaseSQLStateStore
from mautrix.errors import MatrixInvalidToken
from mautrix.types import (
ContentURI,
@@ -41,69 +41,110 @@ from mautrix.types import (
SyncToken,
UserID,
)
+from mautrix.util.async_getter_lock import async_getter_lock
+from mautrix.util.logging import TraceLogger
-from .db import DBClient
-from .lib.store_proxy import SyncStoreProxy
+from .db import Client as DBClient
from .matrix import MaubotMatrixClient
try:
- from mautrix.crypto import OlmMachine, PgCryptoStore, StateStore as CryptoStateStore
- from mautrix.util.async_db import Database as AsyncDatabase
-
- class SQLStateStore(BaseSQLStateStore, CryptoStateStore):
- pass
+ from mautrix.crypto import OlmMachine, PgCryptoStore
crypto_import_error = None
except ImportError as e:
- OlmMachine = CryptoStateStore = PgCryptoStore = AsyncDatabase = None
- SQLStateStore = BaseSQLStateStore
+ OlmMachine = PgCryptoStore = None
crypto_import_error = e
if TYPE_CHECKING:
- from .config import Config
+ from .__main__ import Maubot
from .instance import PluginInstance
-log = logging.getLogger("maubot.client")
-
-class Client:
- log: logging.Logger = None
- loop: asyncio.AbstractEventLoop = None
+class Client(DBClient):
+ maubot: "Maubot" = None
cache: dict[UserID, Client] = {}
+ _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
+ log: TraceLogger = logging.getLogger("maubot.client")
+
http_client: ClientSession = None
- global_state_store: BaseSQLStateStore | CryptoStateStore = SQLStateStore()
- crypto_db: AsyncDatabase | None = None
references: set[PluginInstance]
- db_instance: DBClient
client: MaubotMatrixClient
crypto: OlmMachine | None
crypto_store: PgCryptoStore | None
started: bool
+ sync_ok: bool
remote_displayname: str | None
remote_avatar_url: ContentURI | None
- def __init__(self, db_instance: DBClient) -> None:
- self.db_instance = db_instance
+ def __init__(
+ self,
+ id: UserID,
+ homeserver: str,
+ access_token: str,
+ device_id: DeviceID,
+ enabled: bool = False,
+ next_batch: SyncToken = "",
+ filter_id: FilterID = "",
+ sync: bool = True,
+ autojoin: bool = True,
+ online: bool = True,
+ displayname: str = "disable",
+ avatar_url: str = "disable",
+ ) -> None:
+ super().__init__(
+ id=id,
+ homeserver=homeserver,
+ access_token=access_token,
+ device_id=device_id,
+ enabled=bool(enabled),
+ next_batch=next_batch,
+ filter_id=filter_id,
+ sync=bool(sync),
+ autojoin=bool(autojoin),
+ online=bool(online),
+ displayname=displayname,
+ avatar_url=avatar_url,
+ )
+ self._postinited = False
+
+ def __hash__(self) -> int:
+ return hash(self.id)
+
+ @classmethod
+ def init_cls(cls, maubot: "Maubot") -> None:
+ cls.maubot = maubot
+
+ def _make_client(
+ self, homeserver: str | None = None, token: str | None = None, device_id: str | None = None
+ ) -> MaubotMatrixClient:
+ return MaubotMatrixClient(
+ mxid=self.id,
+ base_url=homeserver or self.homeserver,
+ token=token or self.access_token,
+ client_session=self.http_client,
+ log=self.log,
+ crypto_log=self.log.getChild("crypto"),
+ loop=self.maubot.loop,
+ device_id=device_id or self.device_id,
+ sync_store=self,
+ state_store=self.maubot.state_store,
+ )
+
+ def postinit(self) -> None:
+ if self._postinited:
+ raise RuntimeError("postinit() called twice")
+ self._postinited = True
self.cache[self.id] = self
- self.log = log.getChild(self.id)
+ self.log = self.log.getChild(self.id)
+ self.http_client = ClientSession(loop=self.maubot.loop)
self.references = set()
self.started = False
self.sync_ok = True
self.remote_displayname = None
self.remote_avatar_url = None
- self.client = MaubotMatrixClient(
- mxid=self.id,
- base_url=self.homeserver,
- token=self.access_token,
- client_session=self.http_client,
- log=self.log,
- loop=self.loop,
- device_id=self.device_id,
- sync_store=SyncStoreProxy(self.db_instance),
- state_store=self.global_state_store,
- )
+ self.client = self._make_client()
if self.enable_crypto:
self._prepare_crypto()
else:
@@ -118,6 +159,12 @@ class Client:
self.client.add_event_handler(InternalEventType.SYNC_ERRORED, self._set_sync_ok(False))
self.client.add_event_handler(InternalEventType.SYNC_SUCCESSFUL, self._set_sync_ok(True))
+ def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
+ async def handler(data: dict[str, Any]) -> None:
+ self.sync_ok = ok
+
+ return handler
+
@property
def enable_crypto(self) -> bool:
if not self.device_id:
@@ -131,16 +178,21 @@ class Client:
# Clear the stack trace after it's logged once to avoid spamming logs
crypto_import_error = None
return False
- elif not self.crypto_db:
+ elif not self.maubot.crypto_db:
self.log.warning("Client has device ID, but crypto database is not prepared")
return False
return True
def _prepare_crypto(self) -> None:
self.crypto_store = PgCryptoStore(
- account_id=self.id, pickle_key="mau.crypto", db=self.crypto_db
+ account_id=self.id, pickle_key="mau.crypto", db=self.maubot.crypto_db
+ )
+ self.crypto = OlmMachine(
+ self.client,
+ self.crypto_store,
+ self.maubot.state_store,
+ log=self.client.crypto_log,
)
- self.crypto = OlmMachine(self.client, self.crypto_store, self.global_state_store)
self.client.crypto = self.crypto
def _remove_crypto_event_handlers(self) -> None:
@@ -156,12 +208,6 @@ class Client:
for event_type, func in handlers:
self.client.remove_event_handler(event_type, func)
- def _set_sync_ok(self, ok: bool) -> Callable[[dict[str, Any]], Awaitable[None]]:
- async def handler(data: dict[str, Any]) -> None:
- self.sync_ok = ok
-
- return handler
-
async def start(self, try_n: int | None = 0) -> None:
try:
if try_n > 0:
@@ -196,47 +242,50 @@ class Client:
whoami = await self.client.whoami()
except MatrixInvalidToken as e:
self.log.error(f"Invalid token: {e}. Disabling client")
- self.db_instance.enabled = False
+ self.enabled = False
+ await self.update()
return
except Exception as e:
if try_n >= 8:
self.log.exception("Failed to get /account/whoami, disabling client")
- self.db_instance.enabled = False
+ self.enabled = False
+ await self.update()
else:
self.log.warning(
- f"Failed to get /account/whoami, " f"retrying in {(try_n + 1) * 10}s: {e}"
+ f"Failed to get /account/whoami, retrying in {(try_n + 1) * 10}s: {e}"
)
- _ = asyncio.ensure_future(self.start(try_n + 1), loop=self.loop)
+ _ = asyncio.create_task(self.start(try_n + 1))
return
if whoami.user_id != self.id:
self.log.error(f"User ID mismatch: expected {self.id}, but got {whoami.user_id}")
- self.db_instance.enabled = False
+ self.enabled = False
+ await self.update()
return
elif whoami.device_id and self.device_id and whoami.device_id != self.device_id:
self.log.error(
f"Device ID mismatch: expected {self.device_id}, " f"but got {whoami.device_id}"
)
- self.db_instance.enabled = False
+ self.enabled = False
+ await self.update()
return
if not self.filter_id:
- self.db_instance.edit(
- filter_id=await self.client.create_filter(
- Filter(
- room=RoomFilter(
- timeline=RoomEventFilter(
- limit=50,
- lazy_load_members=True,
- ),
- state=StateFilter(
- lazy_load_members=True,
- ),
+ self.filter_id = await self.client.create_filter(
+ Filter(
+ room=RoomFilter(
+ timeline=RoomEventFilter(
+ limit=50,
+ lazy_load_members=True,
),
- presence=EventFilter(
- not_types=[EventType.PRESENCE],
+ state=StateFilter(
+ lazy_load_members=True,
),
- )
+ ),
+ presence=EventFilter(
+ not_types=[EventType.PRESENCE],
+ ),
)
)
+ await self.update()
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
if self.avatar_url != "disable":
@@ -270,18 +319,13 @@ class Client:
if self.crypto:
await self.crypto_store.close()
- def clear_cache(self) -> None:
+ async def clear_cache(self) -> None:
self.stop_sync()
- self.db_instance.edit(filter_id="", next_batch="")
+ self.filter_id = FilterID("")
+ self.next_batch = SyncToken("")
+ await self.update()
self.start_sync()
- def delete(self) -> None:
- try:
- del self.cache[self.id]
- except KeyError:
- pass
- self.db_instance.delete()
-
def to_dict(self) -> dict:
return {
"id": self.id,
@@ -304,20 +348,6 @@ class Client:
"instances": [instance.to_dict() for instance in self.references],
}
- @classmethod
- def get(cls, user_id: UserID, db_instance: DBClient | None = None) -> Client | None:
- try:
- return cls.cache[user_id]
- except KeyError:
- db_instance = db_instance or DBClient.get(user_id)
- if not db_instance:
- return None
- return Client(db_instance)
-
- @classmethod
- def all(cls) -> Iterable[Client]:
- return (cls.get(user.id, user) for user in DBClient.all())
-
async def _handle_tombstone(self, evt: StateEvent) -> None:
if not evt.content.replacement_room:
self.log.info(f"{evt.room_id} tombstoned with no replacement, ignoring")
@@ -329,7 +359,7 @@ class Client:
if evt.state_key == self.id and evt.content.membership == Membership.INVITE:
await self.client.join_room(evt.room_id)
- async def update_started(self, started: bool) -> None:
+ async def update_started(self, started: bool | None) -> None:
if started is None or started == self.started:
return
if started:
@@ -337,23 +367,65 @@ class Client:
else:
await self.stop()
- async def update_displayname(self, displayname: str) -> None:
+ async def update_enabled(self, enabled: bool | None, save: bool = True) -> None:
+ if enabled is None or enabled == self.enabled:
+ return
+ self.enabled = enabled
+ if save:
+ await self.update()
+
+ async def update_displayname(self, displayname: str | None, save: bool = True) -> None:
if displayname is None or displayname == self.displayname:
return
- self.db_instance.displayname = displayname
+ self.displayname = displayname
if self.displayname != "disable":
await self.client.set_displayname(self.displayname)
else:
await self._update_remote_profile()
+ if save:
+ await self.update()
- async def update_avatar_url(self, avatar_url: ContentURI) -> None:
+ async def update_avatar_url(self, avatar_url: ContentURI, save: bool = True) -> None:
if avatar_url is None or avatar_url == self.avatar_url:
return
- self.db_instance.avatar_url = avatar_url
+ self.avatar_url = avatar_url
if self.avatar_url != "disable":
await self.client.set_avatar_url(self.avatar_url)
else:
await self._update_remote_profile()
+ if save:
+ await self.update()
+
+ async def update_sync(self, sync: bool | None, save: bool = True) -> None:
+ if sync is None or self.sync == sync:
+ return
+ self.sync = sync
+ if self.started:
+ if sync:
+ self.start_sync()
+ else:
+ self.stop_sync()
+ if save:
+ await self.update()
+
+ async def update_autojoin(self, autojoin: bool | None, save: bool = True) -> None:
+ if autojoin is None or autojoin == self.autojoin:
+ return
+ if autojoin:
+ self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
+ else:
+ self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
+ self.autojoin = autojoin
+ if save:
+ await self.update()
+
+ async def update_online(self, online: bool | None, save: bool = True) -> None:
+ if online is None or online == self.online:
+ return
+ self.client.presence = PresenceState.ONLINE if online else PresenceState.OFFLINE
+ self.online = online
+ if save:
+ await self.update()
async def update_access_details(
self,
@@ -373,22 +445,13 @@ class Client:
and device_id == self.device_id
):
return
- new_client = MaubotMatrixClient(
- mxid=self.id,
- base_url=homeserver or self.homeserver,
- token=access_token or self.access_token,
- loop=self.loop,
- device_id=device_id,
- client_session=self.http_client,
- log=self.log,
- state_store=self.global_state_store,
- )
+ new_client = self._make_client(homeserver, access_token, device_id)
whoami = await new_client.whoami()
if whoami.user_id != self.id:
raise ValueError(f"MXID mismatch: {whoami.user_id}")
elif whoami.device_id and device_id and whoami.device_id != device_id:
raise ValueError(f"Device ID mismatch: {whoami.device_id}")
- new_client.sync_store = SyncStoreProxy(self.db_instance)
+ new_client.sync_store = self
self.stop_sync()
# TODO this event handler transfer is pretty hacky
@@ -398,9 +461,9 @@ class Client:
new_client.global_event_handlers = self.client.global_event_handlers
self.client = new_client
- self.db_instance.homeserver = homeserver
- self.db_instance.access_token = access_token
- self.db_instance.device_id = device_id
+ self.homeserver = homeserver
+ self.access_token = access_token
+ self.device_id = device_id
if self.enable_crypto:
self._prepare_crypto()
await self._start_crypto()
@@ -413,97 +476,53 @@ class Client:
profile = await self.client.get_profile(self.id)
self.remote_displayname, self.remote_avatar_url = profile.displayname, profile.avatar_url
- # region Properties
+ async def delete(self) -> None:
+ try:
+ del self.cache[self.id]
+ except KeyError:
+ pass
+ await super().delete()
- @property
- def id(self) -> UserID:
- return self.db_instance.id
+ @classmethod
+ @async_getter_lock
+ async def get(
+ cls,
+ user_id: UserID,
+ *,
+ homeserver: str | None = None,
+ access_token: str | None = None,
+ device_id: DeviceID | None = None,
+ ) -> Client | None:
+ try:
+ return cls.cache[user_id]
+ except KeyError:
+ pass
- @property
- def homeserver(self) -> str:
- return self.db_instance.homeserver
+ user = cast(cls, await super().get(user_id))
+ if user is not None:
+ user.postinit()
+ return user
- @property
- def access_token(self) -> str:
- return self.db_instance.access_token
+ if homeserver and access_token:
+ user = cls(
+ user_id,
+ homeserver=homeserver,
+ access_token=access_token,
+ device_id=device_id or "",
+ )
+ await user.insert()
+ user.postinit()
+ return user
- @property
- def device_id(self) -> DeviceID:
- return self.db_instance.device_id
+ return None
- @property
- def enabled(self) -> bool:
- return self.db_instance.enabled
-
- @enabled.setter
- def enabled(self, value: bool) -> None:
- self.db_instance.enabled = value
-
- @property
- def next_batch(self) -> SyncToken:
- return self.db_instance.next_batch
-
- @property
- def filter_id(self) -> FilterID:
- return self.db_instance.filter_id
-
- @property
- def sync(self) -> bool:
- return self.db_instance.sync
-
- @sync.setter
- def sync(self, value: bool) -> None:
- if value == self.db_instance.sync:
- return
- self.db_instance.sync = value
- if self.started:
- if value:
- self.start_sync()
- else:
- self.stop_sync()
-
- @property
- def autojoin(self) -> bool:
- return self.db_instance.autojoin
-
- @autojoin.setter
- def autojoin(self, value: bool) -> None:
- if value == self.db_instance.autojoin:
- return
- if value:
- self.client.add_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
- else:
- self.client.remove_event_handler(EventType.ROOM_MEMBER, self._handle_invite)
- self.db_instance.autojoin = value
-
- @property
- def online(self) -> bool:
- return self.db_instance.online
-
- @online.setter
- def online(self, value: bool) -> None:
- self.client.presence = PresenceState.ONLINE if value else PresenceState.OFFLINE
- self.db_instance.online = value
-
- @property
- def displayname(self) -> str:
- return self.db_instance.displayname
-
- @property
- def avatar_url(self) -> ContentURI:
- return self.db_instance.avatar_url
-
- # endregion
-
-
-def init(config: "Config", loop: asyncio.AbstractEventLoop) -> Iterable[Client]:
- Client.http_client = ClientSession(loop=loop)
- Client.loop = loop
-
- if OlmMachine:
- db_url = config["crypto_database"]
- if db_url == "default":
- db_url = config["database"]
- Client.crypto_db = AsyncDatabase.create(db_url, upgrade_table=PgCryptoStore.upgrade_table)
-
- return Client.all()
+ @classmethod
+ async def all(cls) -> AsyncGenerator[Client, None]:
+ users = await super().all()
+ user: cls
+ for user in users:
+ try:
+ yield cls.cache[user.id]
+ except KeyError:
+ user.postinit()
+ yield user
diff --git a/maubot/db.py b/maubot/db.py
deleted file mode 100644
index 9f388d3..0000000
--- a/maubot/db.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# maubot - A plugin-based Matrix bot system.
-# Copyright (C) 2022 Tulir Asokan
-#
-# This program is free software: you can redistribute it and/or modify
-# it under the terms of the GNU Affero General Public License as published by
-# the Free Software Foundation, either version 3 of the License, or
-# (at your option) any later version.
-#
-# This program is distributed in the hope that it will be useful,
-# but WITHOUT ANY WARRANTY; without even the implied warranty of
-# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-# GNU Affero General Public License for more details.
-#
-# You should have received a copy of the GNU Affero General Public License
-# along with this program. If not, see .
-from typing import Iterable, Optional
-import logging
-import sys
-
-from sqlalchemy import Boolean, Column, ForeignKey, String, Text
-from sqlalchemy.engine.base import Engine
-import sqlalchemy as sql
-
-from mautrix.client.state_store.sqlalchemy import RoomState, UserProfile
-from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
-from mautrix.util.db import Base
-
-from .config import Config
-
-
-class DBPlugin(Base):
- __tablename__ = "plugin"
-
- id: str = Column(String(255), primary_key=True)
- type: str = Column(String(255), nullable=False)
- enabled: bool = Column(Boolean, nullable=False, default=False)
- primary_user: UserID = Column(
- String(255),
- ForeignKey("client.id", onupdate="CASCADE", ondelete="RESTRICT"),
- nullable=False,
- )
- config: str = Column(Text, nullable=False, default="")
-
- @classmethod
- def all(cls) -> Iterable["DBPlugin"]:
- return cls._select_all()
-
- @classmethod
- def get(cls, id: str) -> Optional["DBPlugin"]:
- return cls._select_one_or_none(cls.c.id == id)
-
-
-class DBClient(Base):
- __tablename__ = "client"
-
- id: UserID = Column(String(255), primary_key=True)
- homeserver: str = Column(String(255), nullable=False)
- access_token: str = Column(Text, nullable=False)
- device_id: DeviceID = Column(String(255), nullable=True)
- enabled: bool = Column(Boolean, nullable=False, default=False)
-
- next_batch: SyncToken = Column(String(255), nullable=False, default="")
- filter_id: FilterID = Column(String(255), nullable=False, default="")
-
- sync: bool = Column(Boolean, nullable=False, default=True)
- autojoin: bool = Column(Boolean, nullable=False, default=True)
- online: bool = Column(Boolean, nullable=False, default=True)
-
- displayname: str = Column(String(255), nullable=False, default="")
- avatar_url: ContentURI = Column(String(255), nullable=False, default="")
-
- @classmethod
- def all(cls) -> Iterable["DBClient"]:
- return cls._select_all()
-
- @classmethod
- def get(cls, id: str) -> Optional["DBClient"]:
- return cls._select_one_or_none(cls.c.id == id)
-
-
-def init(config: Config) -> Engine:
- db = sql.create_engine(config["database"])
- Base.metadata.bind = db
-
- for table in (DBPlugin, DBClient, RoomState, UserProfile):
- table.bind(db)
-
- if not db.has_table("alembic_version"):
- log = logging.getLogger("maubot.db")
-
- if db.has_table("client") and db.has_table("plugin"):
- log.warning(
- "alembic_version table not found, but client and plugin tables found. "
- "Assuming pre-Alembic database and inserting version."
- )
- db.execute(
- "CREATE TABLE IF NOT EXISTS alembic_version ("
- " version_num VARCHAR(32) PRIMARY KEY"
- ");"
- )
- db.execute("INSERT INTO alembic_version VALUES ('d295f8dcfa64');")
- else:
- log.critical(
- "alembic_version table not found. " "Did you forget to `alembic upgrade head`?"
- )
- sys.exit(10)
-
- return db
diff --git a/maubot/db/__init__.py b/maubot/db/__init__.py
new file mode 100644
index 0000000..d6aeb09
--- /dev/null
+++ b/maubot/db/__init__.py
@@ -0,0 +1,13 @@
+from mautrix.util.async_db import Database
+
+from .client import Client
+from .instance import Instance
+from .upgrade import upgrade_table
+
+
+def init(db: Database) -> None:
+ for table in (Client, Instance):
+ table.db = db
+
+
+__all__ = ["upgrade_table", "init", "Client", "Instance"]
diff --git a/maubot/db/client.py b/maubot/db/client.py
new file mode 100644
index 0000000..52f3a20
--- /dev/null
+++ b/maubot/db/client.py
@@ -0,0 +1,114 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+from asyncpg import Record
+from attr import dataclass
+
+from mautrix.client import SyncStore
+from mautrix.types import ContentURI, DeviceID, FilterID, SyncToken, UserID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class Client(SyncStore):
+ db: ClassVar[Database] = fake_db
+
+ id: UserID
+ homeserver: str
+ access_token: str
+ device_id: DeviceID
+ enabled: bool
+
+ next_batch: SyncToken
+ filter_id: FilterID
+
+ sync: bool
+ autojoin: bool
+ online: bool
+
+ displayname: str
+ avatar_url: ContentURI
+
+ @classmethod
+ def _from_row(cls, row: Record | None) -> Client | None:
+ if row is None:
+ return None
+ return cls(**row)
+
+ _columns = (
+ "id, homeserver, access_token, device_id, enabled, next_batch, filter_id, "
+ "sync, autojoin, online, displayname, avatar_url"
+ )
+
+ @property
+ def _values(self):
+ return (
+ self.id,
+ self.homeserver,
+ self.access_token,
+ self.device_id,
+ self.enabled,
+ self.next_batch,
+ self.filter_id,
+ self.sync,
+ self.autojoin,
+ self.online,
+ self.displayname,
+ self.avatar_url,
+ )
+
+ @classmethod
+ async def all(cls) -> list[Client]:
+ rows = await cls.db.fetch(f"SELECT {cls._columns} FROM client")
+ return [cls._from_row(row) for row in rows]
+
+ @classmethod
+ async def get(cls, id: str) -> Client | None:
+ q = f"SELECT {cls._columns} FROM client WHERE id=$1"
+ return cls._from_row(await cls.db.fetchrow(q, id))
+
+ async def insert(self) -> None:
+ q = """
+ INSERT INTO client (
+ id, homeserver, access_token, device_id, enabled, next_batch, filter_id,
+ sync, autojoin, online, displayname, avatar_url
+ ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
+ """
+ await self.db.execute(q, *self._values)
+
+ async def put_next_batch(self, next_batch: SyncToken) -> None:
+ await self.db.execute("UPDATE client SET next_batch=$1 WHERE id=$2", next_batch, self.id)
+ self.next_batch = next_batch
+
+ async def get_next_batch(self) -> SyncToken:
+ return self.next_batch
+
+ async def update(self) -> None:
+ q = """
+ UPDATE client SET homeserver=$2, access_token=$3, device_id=$4, enabled=$5,
+ next_batch=$6, filter_id=$7, sync=$8, autojoin=$9, online=$10,
+ displayname=$11, avatar_url=$12
+ WHERE id=$1
+ """
+ await self.db.execute(q, *self._values)
+
+ async def delete(self) -> None:
+ await self.db.execute("DELETE FROM client WHERE id=$1", self.id)
diff --git a/maubot/db/instance.py b/maubot/db/instance.py
new file mode 100644
index 0000000..dff7064
--- /dev/null
+++ b/maubot/db/instance.py
@@ -0,0 +1,75 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, ClassVar
+
+from asyncpg import Record
+from attr import dataclass
+
+from mautrix.types import UserID
+from mautrix.util.async_db import Database
+
+fake_db = Database.create("") if TYPE_CHECKING else None
+
+
+@dataclass
+class Instance:
+ db: ClassVar[Database] = fake_db
+
+ id: str
+ type: str
+ enabled: bool
+ primary_user: UserID
+ config_str: str
+
+ @classmethod
+ def _from_row(cls, row: Record | None) -> Instance | None:
+ if row is None:
+ return None
+ return cls(**row)
+
+ @classmethod
+ async def all(cls) -> list[Instance]:
+ rows = await cls.db.fetch("SELECT id, type, enabled, primary_user, config FROM instance")
+ return [cls._from_row(row) for row in rows]
+
+ @classmethod
+ async def get(cls, id: str) -> Instance | None:
+ q = "SELECT id, type, enabled, primary_user, config FROM instance WHERE id=$1"
+ return cls._from_row(await cls.db.fetchrow(q, id))
+
+ async def update_id(self, new_id: str) -> None:
+ await self.db.execute("UPDATE instance SET id=$1 WHERE id=$2", new_id, self.id)
+ self.id = new_id
+
+ @property
+ def _values(self):
+ return self.id, self.type, self.enabled, self.primary_user, self.config_str
+
+ async def insert(self) -> None:
+ q = (
+ "INSERT INTO instance (id, type, enabled, primary_user, config) "
+ "VALUES ($1, $2, $3, $4, $5)"
+ )
+ await self.db.execute(q, *self._values)
+
+ async def update(self) -> None:
+ q = "UPDATE instance SET type=$2, enabled=$3, primary_user=$4, config=$5 WHERE id=$1"
+ await self.db.execute(q, *self._values)
+
+ async def delete(self) -> None:
+ await self.db.execute("DELETE FROM instance WHERE id=$1", self.id)
diff --git a/maubot/db/upgrade/__init__.py b/maubot/db/upgrade/__init__.py
new file mode 100644
index 0000000..146e713
--- /dev/null
+++ b/maubot/db/upgrade/__init__.py
@@ -0,0 +1,5 @@
+from mautrix.util.async_db import UpgradeTable
+
+upgrade_table = UpgradeTable()
+
+from . import v01_initial_revision
diff --git a/maubot/db/upgrade/v01_initial_revision.py b/maubot/db/upgrade/v01_initial_revision.py
new file mode 100644
index 0000000..2da8aff
--- /dev/null
+++ b/maubot/db/upgrade/v01_initial_revision.py
@@ -0,0 +1,136 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from __future__ import annotations
+
+from mautrix.util.async_db import Connection, Scheme
+
+from . import upgrade_table
+
+legacy_version_query = "SELECT version_num FROM alembic_version"
+last_legacy_version = "90aa88820eab"
+
+
+@upgrade_table.register(description="Initial asyncpg revision")
+async def upgrade_v1(conn: Connection, scheme: Scheme) -> None:
+ if await conn.table_exists("alembic_version"):
+ await migrate_legacy_to_v1(conn, scheme)
+ else:
+ return await create_v1_tables(conn)
+
+
+async def create_v1_tables(conn: Connection) -> None:
+ await conn.execute(
+ """CREATE TABLE client (
+ id TEXT PRIMARY KEY,
+ homeserver TEXT NOT NULL,
+ access_token TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ enabled BOOLEAN NOT NULL,
+
+ next_batch TEXT NOT NULL,
+ filter_id TEXT NOT NULL,
+
+ sync BOOLEAN NOT NULL,
+ autojoin BOOLEAN NOT NULL,
+ online BOOLEAN NOT NULL,
+
+ displayname TEXT NOT NULL,
+ avatar_url TEXT NOT NULL
+ )"""
+ )
+ await conn.execute(
+ """CREATE TABLE instance (
+ id TEXT PRIMARY KEY,
+ type TEXT NOT NULL,
+ enabled BOOLEAN NOT NULL,
+ primary_user TEXT NOT NULL,
+ config TEXT NOT NULL,
+ FOREIGN KEY (primary_user) REFERENCES client(id) ON DELETE RESTRICT ON UPDATE CASCADE
+ )"""
+ )
+
+
+async def migrate_legacy_to_v1(conn: Connection, scheme: Scheme) -> None:
+ legacy_version = await conn.fetchval(legacy_version_query)
+ if legacy_version != last_legacy_version:
+ raise RuntimeError(
+ "Legacy database is not on last version. "
+ "Please upgrade the old database with alembic or drop it completely first."
+ )
+ await conn.execute("ALTER TABLE plugin RENAME TO instance")
+ await update_state_store(conn, scheme)
+ if scheme != Scheme.SQLITE:
+ await varchar_to_text(conn)
+ await conn.execute("DROP TABLE alembic_version")
+
+
+async def update_state_store(conn: Connection, scheme: Scheme) -> None:
+ # The Matrix state store already has more or less the correct schema, so set the version
+ await conn.execute("CREATE TABLE mx_version (version INTEGER PRIMARY KEY)")
+ await conn.execute("INSERT INTO mx_version (version) VALUES (2)")
+ if scheme != Scheme.SQLITE:
+ # Remove old uppercase membership type and recreate it as lowercase
+ await conn.execute("ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE TEXT")
+ await conn.execute("DROP TYPE IF EXISTS membership")
+ await conn.execute(
+ "CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock')"
+ )
+ await conn.execute(
+ "ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership "
+ "USING LOWER(membership)::membership"
+ )
+ else:
+ # Recreate table to remove CHECK constraint and lowercase everything
+ await conn.execute(
+ """CREATE TABLE new_mx_user_profile (
+ room_id TEXT,
+ user_id TEXT,
+ membership TEXT NOT NULL
+ CHECK (membership IN ('join', 'leave', 'invite', 'ban', 'knock')),
+ displayname TEXT,
+ avatar_url TEXT,
+ PRIMARY KEY (room_id, user_id)
+ )"""
+ )
+ await conn.execute(
+ """
+ INSERT INTO new_mx_user_profile (room_id, user_id, membership, displayname, avatar_url)
+ SELECT room_id, user_id, LOWER(membership), displayname, avatar_url
+ FROM mx_user_profile
+ """
+ )
+ await conn.execute("DROP TABLE mx_user_profile")
+ await conn.execute("ALTER TABLE new_mx_user_profile RENAME TO mx_user_profile")
+
+
+async def varchar_to_text(conn: Connection) -> None:
+ columns_to_adjust = {
+ "client": (
+ "id",
+ "homeserver",
+ "device_id",
+ "next_batch",
+ "filter_id",
+ "displayname",
+ "avatar_url",
+ ),
+ "instance": ("id", "type", "primary_user"),
+ "mx_room_state": ("room_id",),
+ "mx_user_profile": ("room_id", "user_id", "displayname", "avatar_url"),
+ }
+ for table, columns in columns_to_adjust.items():
+ for column in columns:
+ await conn.execute(f'ALTER TABLE "{table}" ALTER COLUMN {column} TYPE TEXT')
diff --git a/maubot/example-config.yaml b/maubot/example-config.yaml
index eb9bfe2..0f82e12 100644
--- a/maubot/example-config.yaml
+++ b/maubot/example-config.yaml
@@ -6,9 +6,7 @@
database: sqlite:///maubot.db
# Separate database URL for the crypto database. "default" means use the same database as above.
-# Due to concurrency issues, you should use a separate file when using SQLite rather than the same as above.
-# When using postgres, using the same database for both is safe.
-crypto_database: sqlite:///crypto.db
+crypto_database: default
plugin_directories:
# The directory where uploaded new plugins should be stored.
diff --git a/maubot/instance.py b/maubot/instance.py
index 7d7900b..d615a72 100644
--- a/maubot/instance.py
+++ b/maubot/instance.py
@@ -15,8 +15,10 @@
# along with this program. If not, see .
from __future__ import annotations
-from typing import TYPE_CHECKING, Iterable
-from asyncio import AbstractEventLoop
+from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
+from collections import defaultdict
+import asyncio
+import inspect
import io
import logging
import os.path
@@ -26,16 +28,17 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
+from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from .client import Client
-from .config import Config
-from .db import DBPlugin
+from .db import Instance as DBInstance
from .loader import PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
- from .server import MaubotServer, PluginWebApp
+ from .__main__ import Maubot
+ from .server import PluginWebApp
log = logging.getLogger("maubot.instance")
@@ -44,29 +47,42 @@ yaml.indent(4)
yaml.width = 200
-class PluginInstance:
- webserver: MaubotServer = None
- mb_config: Config = None
- loop: AbstractEventLoop = None
+class PluginInstance(DBInstance):
+ maubot: "Maubot" = None
cache: dict[str, PluginInstance] = {}
plugin_directories: list[str] = []
+ _async_get_locks: dict[Any, asyncio.Lock] = defaultdict(lambda: asyncio.Lock())
log: logging.Logger
- loader: PluginLoader
- client: Client
- plugin: Plugin
- config: BaseProxyConfig
+ loader: PluginLoader | None
+ client: Client | None
+ plugin: Plugin | None
+ config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
- inst_db: sql.engine.Engine
- inst_db_tables: dict[str, sql.Table]
+ inst_db: sql.engine.Engine | None
+ inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
started: bool
- def __init__(self, db_instance: DBPlugin):
- self.db_instance = db_instance
+ def __init__(
+ self, id: str, type: str, enabled: bool, primary_user: UserID, config: str = ""
+ ) -> None:
+ super().__init__(
+ id=id, type=type, enabled=bool(enabled), primary_user=primary_user, config_str=config
+ )
+
+ def __hash__(self) -> int:
+ return hash(self.id)
+
+ @classmethod
+ def init_cls(cls, maubot: "Maubot") -> None:
+ cls.maubot = maubot
+
+ def postinit(self) -> None:
self.log = log.getChild(self.id)
+ self.cache[self.id] = self
self.config = None
self.started = False
self.loader = None
@@ -78,7 +94,6 @@ class PluginInstance:
self.inst_webapp_url = None
self.base_cfg = None
self.base_cfg_str = None
- self.cache[self.id] = self
def to_dict(self) -> dict:
return {
@@ -87,10 +102,10 @@ class PluginInstance:
"enabled": self.enabled,
"started": self.started,
"primary_user": self.primary_user,
- "config": self.db_instance.config,
+ "config": self.config_str,
"base_config": self.base_cfg_str,
"database": (
- self.inst_db is not None and self.mb_config["api_features.instance_database"]
+ self.inst_db is not None and self.maubot.config["api_features.instance_database"]
),
}
@@ -101,19 +116,19 @@ class PluginInstance:
self.inst_db_tables = metadata.tables
return self.inst_db_tables
- def load(self) -> bool:
+ async def load(self) -> bool:
if not self.loader:
try:
self.loader = PluginLoader.find(self.type)
except KeyError:
self.log.error(f"Failed to find loader for type {self.type}")
- self.db_instance.enabled = False
+ await self.update_enabled(False)
return False
if not self.client:
- self.client = Client.get(self.primary_user)
+ self.client = await Client.get(self.primary_user)
if not self.client:
self.log.error(f"Failed to get client for user {self.primary_user}")
- self.db_instance.enabled = False
+ await self.update_enabled(False)
return False
if self.loader.meta.database:
self.enable_database()
@@ -125,18 +140,18 @@ class PluginInstance:
return True
def enable_webapp(self) -> None:
- self.inst_webapp, self.inst_webapp_url = self.webserver.get_instance_subapp(self.id)
+ self.inst_webapp, self.inst_webapp_url = self.maubot.server.get_instance_subapp(self.id)
def disable_webapp(self) -> None:
- self.webserver.remove_instance_webapp(self.id)
+ self.maubot.server.remove_instance_webapp(self.id)
self.inst_webapp = None
self.inst_webapp_url = None
def enable_database(self) -> None:
- db_path = os.path.join(self.mb_config["plugin_directories.db"], self.id)
+ db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
- def delete(self) -> None:
+ async def delete(self) -> None:
if self.loader is not None:
self.loader.references.remove(self)
if self.client is not None:
@@ -145,23 +160,23 @@ class PluginInstance:
del self.cache[self.id]
except KeyError:
pass
- self.db_instance.delete()
+ await super().delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
- os.path.join(self.mb_config["plugin_directories.db"], f"{self.id}.db"),
+ os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted",
)
if self.inst_webapp:
self.disable_webapp()
def load_config(self) -> CommentedMap:
- return yaml.load(self.db_instance.config)
+ return yaml.load(self.config_str)
def save_config(self, data: RecursiveDict[CommentedMap]) -> None:
buf = io.StringIO()
yaml.dump(data, buf)
- self.db_instance.config = buf.getvalue()
+ self.config_str = buf.getvalue()
async def start(self) -> None:
if self.started:
@@ -172,7 +187,7 @@ class PluginInstance:
return
if not self.client or not self.loader:
self.log.warning("Missing plugin instance dependencies, attempting to load...")
- if not self.load():
+ if not await self.load():
return
cls = await self.loader.load()
if self.loader.meta.webapp and self.inst_webapp is None:
@@ -205,7 +220,7 @@ class PluginInstance:
self.config = config_class(self.load_config, base_cfg_func, self.save_config)
self.plugin = cls(
client=self.client.client,
- loop=self.loop,
+ loop=self.maubot.loop,
http=self.client.http_client,
instance_id=self.id,
log=self.log,
@@ -219,7 +234,7 @@ class PluginInstance:
await self.plugin.internal_start()
except Exception:
self.log.exception("Failed to start instance")
- self.db_instance.enabled = False
+ await self.update_enabled(False)
return
self.started = True
self.inst_db_tables = None
@@ -241,60 +256,51 @@ class PluginInstance:
self.plugin = None
self.inst_db_tables = None
- @classmethod
- def get(cls, instance_id: str, db_instance: DBPlugin | None = None) -> PluginInstance | None:
- try:
- return cls.cache[instance_id]
- except KeyError:
- db_instance = db_instance or DBPlugin.get(instance_id)
- if not db_instance:
- return None
- return PluginInstance(db_instance)
+ async def update_id(self, new_id: str | None) -> None:
+ if new_id is not None and new_id.lower() != self.id:
+ await super().update_id(new_id.lower())
- @classmethod
- def all(cls) -> Iterable[PluginInstance]:
- return (cls.get(plugin.id, plugin) for plugin in DBPlugin.all())
-
- def update_id(self, new_id: str) -> None:
- if new_id is not None and new_id != self.id:
- self.db_instance.id = new_id.lower()
-
- def update_config(self, config: str) -> None:
- if not config or self.db_instance.config == config:
+ async def update_config(self, config: str | None) -> None:
+ if config is None or self.config_str == config:
return
- self.db_instance.config = config
+ self.config_str = config
if self.started and self.plugin is not None:
- self.plugin.on_external_config_update()
+ res = self.plugin.on_external_config_update()
+ if inspect.isawaitable(res):
+ await res
+ await self.update()
- async def update_primary_user(self, primary_user: UserID) -> bool:
- if not primary_user or primary_user == self.primary_user:
+ async def update_primary_user(self, primary_user: UserID | None) -> bool:
+ if primary_user is None or primary_user == self.primary_user:
return True
- client = Client.get(primary_user)
+ client = await Client.get(primary_user)
if not client:
return False
await self.stop()
- self.db_instance.primary_user = client.id
+ self.primary_user = client.id
if self.client:
self.client.references.remove(self)
self.client = client
self.client.references.add(self)
+ await self.update()
await self.start()
self.log.debug(f"Primary user switched to {self.client.id}")
return True
- async def update_type(self, type: str) -> bool:
- if not type or type == self.type:
+ async def update_type(self, type: str | None) -> bool:
+ if type is None or type == self.type:
return True
try:
loader = PluginLoader.find(type)
except KeyError:
return False
await self.stop()
- self.db_instance.type = loader.meta.id
+ self.type = loader.meta.id
if self.loader:
self.loader.references.remove(self)
self.loader = loader
self.loader.references.add(self)
+ await self.update()
await self.start()
self.log.debug(f"Type switched to {self.loader.meta.id}")
return True
@@ -303,39 +309,41 @@ class PluginInstance:
if started is not None and started != self.started:
await (self.start() if started else self.stop())
- def update_enabled(self, enabled: bool) -> None:
+ async def update_enabled(self, enabled: bool) -> None:
if enabled is not None and enabled != self.enabled:
- self.db_instance.enabled = enabled
+ self.enabled = enabled
+ await self.update()
- # region Properties
+ @classmethod
+ @async_getter_lock
+ async def get(
+ cls, instance_id: str, *, type: str | None = None, primary_user: UserID | None = None
+ ) -> PluginInstance | None:
+ try:
+ return cls.cache[instance_id]
+ except KeyError:
+ pass
- @property
- def id(self) -> str:
- return self.db_instance.id
+ instance = cast(cls, await super().get(instance_id))
+ if instance is not None:
+ instance.postinit()
+ return instance
- @id.setter
- def id(self, value: str) -> None:
- self.db_instance.id = value
+ if type and primary_user:
+ instance = cls(instance_id, type=type, enabled=True, primary_user=primary_user)
+ await instance.insert()
+ instance.postinit()
+ return instance
- @property
- def type(self) -> str:
- return self.db_instance.type
+ return None
- @property
- def enabled(self) -> bool:
- return self.db_instance.enabled
-
- @property
- def primary_user(self) -> UserID:
- return self.db_instance.primary_user
-
- # endregion
-
-
-def init(
- config: Config, webserver: MaubotServer, loop: AbstractEventLoop
-) -> Iterable[PluginInstance]:
- PluginInstance.mb_config = config
- PluginInstance.loop = loop
- PluginInstance.webserver = webserver
- return PluginInstance.all()
+ @classmethod
+ async def all(cls) -> AsyncGenerator[PluginInstance, None]:
+ instances = await super().all()
+ instance: PluginInstance
+ for instance in instances:
+ try:
+ yield cls.cache[instance.id]
+ except KeyError:
+ instance.postinit()
+ yield instance
diff --git a/maubot/lib/color_log.py b/maubot/lib/color_log.py
index 104e9f7..4fb94e0 100644
--- a/maubot/lib/color_log.py
+++ b/maubot/lib/color_log.py
@@ -28,14 +28,19 @@ LOADER_COLOR = PREFIX + "36m" # blue
class ColorFormatter(BaseColorFormatter):
def _color_name(self, module: str) -> str:
client = "maubot.client"
- if module.startswith(client):
- return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module[len(client) + 1:]}{RESET}"
+ if module.startswith(client + "."):
+ suffix = ""
+ if module.endswith(".crypto"):
+ suffix = f".{MAU_COLOR}crypto{RESET}"
+ module = module[: -len(".crypto")]
+ module = module[len(client) + 1 :]
+ return f"{MAU_COLOR}{client}{RESET}.{MXID_COLOR}{module}{RESET}{suffix}"
instance = "maubot.instance"
- if module.startswith(instance):
+ if module.startswith(instance + "."):
return f"{MAU_COLOR}{instance}{RESET}.{INST_COLOR}{module[len(instance) + 1:]}{RESET}"
loader = "maubot.loader"
- if module.startswith(loader):
+ if module.startswith(loader + "."):
return f"{MAU_COLOR}{instance}{RESET}.{LOADER_COLOR}{module[len(loader) + 1:]}{RESET}"
- if module.startswith("maubot"):
+ if module.startswith("maubot."):
return f"{MAU_COLOR}{module}{RESET}"
return super()._color_name(module)
diff --git a/maubot/lib/store_proxy.py b/maubot/lib/state_store.py
similarity index 64%
rename from maubot/lib/store_proxy.py
rename to maubot/lib/state_store.py
index d8fa234..81fb5fd 100644
--- a/maubot/lib/store_proxy.py
+++ b/maubot/lib/state_store.py
@@ -13,16 +13,15 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from mautrix.client import SyncStore
-from mautrix.types import SyncToken
+from mautrix.client.state_store.asyncpg import PgStateStore as BasePgStateStore
+try:
+ from mautrix.crypto import StateStore as CryptoStateStore
-class SyncStoreProxy(SyncStore):
- def __init__(self, db_instance) -> None:
- self.db_instance = db_instance
+ class PgStateStore(BasePgStateStore, CryptoStateStore):
+ pass
- async def put_next_batch(self, next_batch: SyncToken) -> None:
- self.db_instance.edit(next_batch=next_batch)
+except ImportError as e:
+ PgStateStore = BasePgStateStore
- async def get_next_batch(self) -> SyncToken:
- return self.db_instance.next_batch
+__all__ = ["PgStateStore"]
diff --git a/maubot/loader/abc.py b/maubot/loader/abc.py
index f99358c..c669398 100644
--- a/maubot/loader/abc.py
+++ b/maubot/loader/abc.py
@@ -13,17 +13,14 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import TYPE_CHECKING, Dict, List, Set, Type, TypeVar
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, TypeVar
from abc import ABC, abstractmethod
import asyncio
-from attr import dataclass
-from packaging.version import InvalidVersion, Version
-
-from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
-
-from ..__meta__ import __version__
from ..plugin_base import Plugin
+from .meta import PluginMeta
if TYPE_CHECKING:
from ..instance import PluginInstance
@@ -35,36 +32,6 @@ class IDConflictError(Exception):
pass
-@serializer(Version)
-def serialize_version(version: Version) -> str:
- return str(version)
-
-
-@deserializer(Version)
-def deserialize_version(version: str) -> Version:
- try:
- return Version(version)
- except InvalidVersion as e:
- raise SerializerError("Invalid version") from e
-
-
-@dataclass
-class PluginMeta(SerializableAttrs):
- id: str
- version: Version
- modules: List[str]
- main_class: str
-
- maubot: Version = Version(__version__)
- database: bool = False
- config: bool = False
- webapp: bool = False
- license: str = ""
- extra_files: List[str] = []
- dependencies: List[str] = []
- soft_dependencies: List[str] = []
-
-
class BasePluginLoader(ABC):
meta: PluginMeta
@@ -80,25 +47,25 @@ class BasePluginLoader(ABC):
async def read_file(self, path: str) -> bytes:
pass
- def sync_list_files(self, directory: str) -> List[str]:
+ def sync_list_files(self, directory: str) -> list[str]:
raise NotImplementedError("This loader doesn't support synchronous operations")
@abstractmethod
- async def list_files(self, directory: str) -> List[str]:
+ async def list_files(self, directory: str) -> list[str]:
pass
class PluginLoader(BasePluginLoader, ABC):
- id_cache: Dict[str, "PluginLoader"] = {}
+ id_cache: dict[str, PluginLoader] = {}
meta: PluginMeta
- references: Set["PluginInstance"]
+ references: set[PluginInstance]
def __init__(self):
self.references = set()
@classmethod
- def find(cls, plugin_id: str) -> "PluginLoader":
+ def find(cls, plugin_id: str) -> PluginLoader:
return cls.id_cache[plugin_id]
def to_dict(self) -> dict:
@@ -119,11 +86,11 @@ class PluginLoader(BasePluginLoader, ABC):
)
@abstractmethod
- async def load(self) -> Type[PluginClass]:
+ async def load(self) -> type[PluginClass]:
pass
@abstractmethod
- async def reload(self) -> Type[PluginClass]:
+ async def reload(self) -> type[PluginClass]:
pass
@abstractmethod
diff --git a/maubot/loader/meta.py b/maubot/loader/meta.py
new file mode 100644
index 0000000..7d44483
--- /dev/null
+++ b/maubot/loader/meta.py
@@ -0,0 +1,53 @@
+# maubot - A plugin-based Matrix bot system.
+# Copyright (C) 2022 Tulir Asokan
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU Affero General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU Affero General Public License for more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with this program. If not, see .
+from typing import List
+
+from attr import dataclass
+from packaging.version import InvalidVersion, Version
+
+from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
+
+from ..__meta__ import __version__
+
+
+@serializer(Version)
+def serialize_version(version: Version) -> str:
+ return str(version)
+
+
+@deserializer(Version)
+def deserialize_version(version: str) -> Version:
+ try:
+ return Version(version)
+ except InvalidVersion as e:
+ raise SerializerError("Invalid version") from e
+
+
+@dataclass
+class PluginMeta(SerializableAttrs):
+ id: str
+ version: Version
+ modules: List[str]
+ main_class: str
+
+ maubot: Version = Version(__version__)
+ database: bool = False
+ config: bool = False
+ webapp: bool = False
+ license: str = ""
+ extra_files: List[str] = []
+ dependencies: List[str] = []
+ soft_dependencies: List[str] = []
diff --git a/maubot/loader/zip.py b/maubot/loader/zip.py
index 62db112..739656f 100644
--- a/maubot/loader/zip.py
+++ b/maubot/loader/zip.py
@@ -29,7 +29,8 @@ from mautrix.types import SerializerError
from ..config import Config
from ..lib.zipimport import ZipImportError, zipimporter
from ..plugin_base import Plugin
-from .abc import IDConflictError, PluginClass, PluginLoader, PluginMeta
+from .abc import IDConflictError, PluginClass, PluginLoader
+from .meta import PluginMeta
yaml = YAML()
diff --git a/maubot/management/api/__init__.py b/maubot/management/api/__init__.py
index 1c4d7d3..c2e5f24 100644
--- a/maubot/management/api/__init__.py
+++ b/maubot/management/api/__init__.py
@@ -20,7 +20,7 @@ from aiohttp import web
from ...config import Config
from .auth import check_token
-from .base import get_config, routes, set_config, set_loop
+from .base import get_config, routes, set_config
from .middleware import auth, error
@@ -40,7 +40,6 @@ def features(request: web.Request) -> web.Response:
def init(cfg: Config, loop: AbstractEventLoop) -> web.Application:
set_config(cfg)
- set_loop(loop)
for pkg, enabled in cfg["api_features"].items():
if enabled:
importlib.import_module(f"maubot.management.api.{pkg}")
diff --git a/maubot/management/api/auth.py b/maubot/management/api/auth.py
index 76ddcf3..0abc3ad 100644
--- a/maubot/management/api/auth.py
+++ b/maubot/management/api/auth.py
@@ -46,7 +46,7 @@ def create_token(user: UserID) -> str:
def get_token(request: web.Request) -> str:
token = request.headers.get("Authorization", "")
if not token or not token.startswith("Bearer "):
- token = request.query.get("access_token", None)
+ token = request.query.get("access_token", "")
else:
token = token[len("Bearer ") :]
return token
diff --git a/maubot/management/api/base.py b/maubot/management/api/base.py
index 73b2508..3d7693a 100644
--- a/maubot/management/api/base.py
+++ b/maubot/management/api/base.py
@@ -24,7 +24,6 @@ from ...config import Config
routes: web.RouteTableDef = web.RouteTableDef()
_config: Config | None = None
-_loop: asyncio.AbstractEventLoop | None = None
def set_config(config: Config) -> None:
@@ -36,15 +35,6 @@ def get_config() -> Config:
return _config
-def set_loop(loop: asyncio.AbstractEventLoop) -> None:
- global _loop
- _loop = loop
-
-
-def get_loop() -> asyncio.AbstractEventLoop:
- return _loop
-
-
@routes.get("/version")
async def version(_: web.Request) -> web.Response:
return web.json_response({"version": __version__})
diff --git a/maubot/management/api/client.py b/maubot/management/api/client.py
index 0b3a239..d95286b 100644
--- a/maubot/management/api/client.py
+++ b/maubot/management/api/client.py
@@ -24,7 +24,6 @@ from mautrix.errors import MatrixConnectionError, MatrixInvalidToken, MatrixRequ
from mautrix.types import FilterID, SyncToken, UserID
from ...client import Client
-from ...db import DBClient
from .base import routes
from .responses import resp
@@ -37,7 +36,7 @@ async def get_clients(_: web.Request) -> web.Response:
@routes.get("/client/{id}")
async def get_client(request: web.Request) -> web.Response:
user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
+ client = await Client.get(user_id)
if not client:
return resp.client_not_found
return resp.found(client.to_dict())
@@ -51,7 +50,6 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
mxid="@not:a.mxid",
base_url=homeserver,
token=access_token,
- loop=Client.loop,
client_session=Client.http_client,
)
try:
@@ -63,29 +61,23 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
except MatrixConnectionError:
return resp.bad_client_connection_details
if user_id is None:
- existing_client = Client.get(whoami.user_id, None)
+ existing_client = await Client.get(whoami.user_id)
if existing_client is not None:
return resp.user_exists
elif whoami.user_id != user_id:
return resp.mxid_mismatch(whoami.user_id)
elif whoami.device_id and device_id and whoami.device_id != device_id:
return resp.device_id_mismatch(whoami.device_id)
- db_instance = DBClient(
- id=whoami.user_id,
- homeserver=homeserver,
- access_token=access_token,
- enabled=data.get("enabled", True),
- next_batch=SyncToken(""),
- filter_id=FilterID(""),
- sync=data.get("sync", True),
- autojoin=data.get("autojoin", True),
- online=data.get("online", True),
- displayname=data.get("displayname", "disable"),
- avatar_url=data.get("avatar_url", "disable"),
- device_id=device_id,
+ client = await Client.get(
+ whoami.user_id, homeserver=homeserver, access_token=access_token, device_id=device_id
)
- client = Client(db_instance)
- client.db_instance.insert()
+ client.enabled = data.get("enabled", True)
+ client.sync = data.get("sync", True)
+ client.autojoin = data.get("autojoin", True)
+ client.online = data.get("online", True)
+ client.displayname = data.get("displayname", "disable")
+ client.avatar_url = data.get("avatar_url", "disable")
+ await client.update()
await client.start()
return resp.created(client.to_dict())
@@ -93,9 +85,7 @@ async def _create_client(user_id: UserID | None, data: dict) -> web.Response:
async def _update_client(client: Client, data: dict, is_login: bool = False) -> web.Response:
try:
await client.update_access_details(
- data.get("access_token", None),
- data.get("homeserver", None),
- data.get("device_id", None),
+ data.get("access_token"), data.get("homeserver"), data.get("device_id")
)
except MatrixInvalidToken:
return resp.bad_client_access_token
@@ -109,21 +99,21 @@ async def _update_client(client: Client, data: dict, is_login: bool = False) ->
return resp.mxid_mismatch(str(e)[len("MXID mismatch: ") :])
elif str_err.startswith("Device ID mismatch"):
return resp.device_id_mismatch(str(e)[len("Device ID mismatch: ") :])
- with client.db_instance.edit_mode():
- await client.update_avatar_url(data.get("avatar_url", None))
- await client.update_displayname(data.get("displayname", None))
- await client.update_started(data.get("started", None))
- client.enabled = data.get("enabled", client.enabled)
- client.autojoin = data.get("autojoin", client.autojoin)
- client.online = data.get("online", client.online)
- client.sync = data.get("sync", client.sync)
- return resp.updated(client.to_dict(), is_login=is_login)
+ await client.update_avatar_url(data.get("avatar_url"), save=False)
+ await client.update_displayname(data.get("displayname"), save=False)
+ await client.update_started(data.get("started"))
+ await client.update_enabled(data.get("enabled"), save=False)
+ await client.update_autojoin(data.get("autojoin"), save=False)
+ await client.update_online(data.get("online"), save=False)
+ await client.update_sync(data.get("sync"), save=False)
+ await client.update()
+ return resp.updated(client.to_dict(), is_login=is_login)
async def _create_or_update_client(
user_id: UserID, data: dict, is_login: bool = False
) -> web.Response:
- client = Client.get(user_id, None)
+ client = await Client.get(user_id)
if not client:
return await _create_client(user_id, data)
else:
@@ -141,7 +131,7 @@ async def create_client(request: web.Request) -> web.Response:
@routes.put("/client/{id}")
async def update_client(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
+ user_id = request.match_info["id"]
try:
data = await request.json()
except JSONDecodeError:
@@ -151,23 +141,23 @@ async def update_client(request: web.Request) -> web.Response:
@routes.delete("/client/{id}")
async def delete_client(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
+ user_id = request.match_info["id"]
+ client = await Client.get(user_id)
if not client:
return resp.client_not_found
if len(client.references) > 0:
return resp.client_in_use
if client.started:
await client.stop()
- client.delete()
+ await client.delete()
return resp.deleted
@routes.post("/client/{id}/clearcache")
async def clear_client_cache(request: web.Request) -> web.Response:
- user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
+ user_id = request.match_info["id"]
+ client = await Client.get(user_id)
if not client:
return resp.client_not_found
- client.clear_cache()
+ await client.clear_cache()
return resp.ok
diff --git a/maubot/management/api/client_auth.py b/maubot/management/api/client_auth.py
index 754c0d7..c5baade 100644
--- a/maubot/management/api/client_auth.py
+++ b/maubot/management/api/client_auth.py
@@ -13,7 +13,9 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see .
-from typing import Dict, NamedTuple, Optional, Tuple
+from __future__ import annotations
+
+from typing import NamedTuple
from http import HTTPStatus
from json import JSONDecodeError
import asyncio
@@ -30,12 +32,12 @@ from mautrix.client import ClientAPI
from mautrix.errors import MatrixRequestError
from mautrix.types import LoginResponse, LoginType
-from .base import get_config, get_loop, routes
+from .base import get_config, routes
from .client import _create_client, _create_or_update_client
from .responses import resp
-def known_homeservers() -> Dict[str, Dict[str, str]]:
+def known_homeservers() -> dict[str, dict[str, str]]:
return get_config()["homeservers"]
@@ -61,7 +63,7 @@ truthy_strings = ("1", "true", "yes")
async def read_client_auth_request(
request: web.Request,
-) -> Tuple[Optional[AuthRequestInfo], Optional[web.Response]]:
+) -> tuple[AuthRequestInfo | None, web.Response | None]:
server_name = request.match_info.get("server", None)
server = known_homeservers().get(server_name, None)
if not server:
@@ -85,7 +87,7 @@ async def read_client_auth_request(
return (
AuthRequestInfo(
server_name=server_name,
- client=ClientAPI(base_url=base_url, loop=get_loop()),
+ client=ClientAPI(base_url=base_url),
secret=server.get("secret"),
username=username,
password=password,
@@ -189,11 +191,11 @@ async def _do_sso(req: AuthRequestInfo) -> web.Response:
sso_url = req.client.api.base_url.with_path(str(Path.login.sso.redirect)).with_query(
{"redirectUrl": str(public_url)}
)
- sso_waiters[waiter_id] = req, get_loop().create_future()
+ sso_waiters[waiter_id] = req, asyncio.get_running_loop().create_future()
return web.json_response({"sso_url": str(sso_url), "id": waiter_id})
-async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) -> web.Response:
+async def _do_login(req: AuthRequestInfo, login_token: str | None = None) -> web.Response:
device_id = "".join(random.choices(string.ascii_uppercase + string.digits, k=8))
device_id = f"maubot_{device_id}"
try:
@@ -235,7 +237,7 @@ async def _do_login(req: AuthRequestInfo, login_token: Optional[str] = None) ->
return web.json_response(res.serialize())
-sso_waiters: Dict[str, Tuple[AuthRequestInfo, asyncio.Future]] = {}
+sso_waiters: dict[str, tuple[AuthRequestInfo, asyncio.Future]] = {}
@routes.post("/client/auth/{server}/sso/{id}/wait")
diff --git a/maubot/management/api/client_proxy.py b/maubot/management/api/client_proxy.py
index dca741f..3fa682b 100644
--- a/maubot/management/api/client_proxy.py
+++ b/maubot/management/api/client_proxy.py
@@ -25,7 +25,7 @@ PROXY_CHUNK_SIZE = 32 * 1024
@routes.view("/proxy/{id}/{path:_matrix/.+}")
async def proxy(request: web.Request) -> web.StreamResponse:
user_id = request.match_info.get("id", None)
- client = Client.get(user_id, None)
+ client = await Client.get(user_id)
if not client:
return resp.client_not_found
diff --git a/maubot/management/api/instance.py b/maubot/management/api/instance.py
index c875c6a..edc34bd 100644
--- a/maubot/management/api/instance.py
+++ b/maubot/management/api/instance.py
@@ -18,7 +18,6 @@ from json import JSONDecodeError
from aiohttp import web
from ...client import Client
-from ...db import DBPlugin
from ...instance import PluginInstance
from ...loader import PluginLoader
from .base import routes
@@ -32,56 +31,49 @@ async def get_instances(_: web.Request) -> web.Response:
@routes.get("/instance/{id}")
async def get_instance(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "").lower()
- instance = PluginInstance.get(instance_id, None)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
return resp.found(instance.to_dict())
async def _create_instance(instance_id: str, data: dict) -> web.Response:
- plugin_type = data.get("type", None)
- primary_user = data.get("primary_user", None)
+ plugin_type = data.get("type")
+ primary_user = data.get("primary_user")
if not plugin_type:
return resp.plugin_type_required
elif not primary_user:
return resp.primary_user_required
- elif not Client.get(primary_user):
+ elif not await Client.get(primary_user):
return resp.primary_user_not_found
try:
PluginLoader.find(plugin_type)
except KeyError:
return resp.plugin_type_not_found
- db_instance = DBPlugin(
- id=instance_id,
- type=plugin_type,
- enabled=data.get("enabled", True),
- primary_user=primary_user,
- config=data.get("config", ""),
- )
- instance = PluginInstance(db_instance)
- instance.load()
- instance.db_instance.insert()
+ instance = await PluginInstance.get(instance_id, type=plugin_type, primary_user=primary_user)
+ instance.enabled = data.get("enabled", True)
+ instance.config_str = data.get("config") or ""
+ await instance.update()
await instance.start()
return resp.created(instance.to_dict())
async def _update_instance(instance: PluginInstance, data: dict) -> web.Response:
- if not await instance.update_primary_user(data.get("primary_user", None)):
+ if not await instance.update_primary_user(data.get("primary_user")):
return resp.primary_user_not_found
- with instance.db_instance.edit_mode():
- instance.update_id(data.get("id", None))
- instance.update_enabled(data.get("enabled", None))
- instance.update_config(data.get("config", None))
- await instance.update_started(data.get("started", None))
- await instance.update_type(data.get("type", None))
- return resp.updated(instance.to_dict())
+ await instance.update_id(data.get("id"))
+ await instance.update_enabled(data.get("enabled"))
+ await instance.update_config(data.get("config"))
+ await instance.update_started(data.get("started"))
+ await instance.update_type(data.get("type"))
+ return resp.updated(instance.to_dict())
@routes.put("/instance/{id}")
async def update_instance(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "").lower()
- instance = PluginInstance.get(instance_id, None)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
try:
data = await request.json()
except JSONDecodeError:
@@ -94,11 +86,11 @@ async def update_instance(request: web.Request) -> web.Response:
@routes.delete("/instance/{id}")
async def delete_instance(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "").lower()
- instance = PluginInstance.get(instance_id)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
if instance.started:
await instance.stop()
- instance.delete()
+ await instance.delete()
return resp.deleted
diff --git a/maubot/management/api/instance_database.py b/maubot/management/api/instance_database.py
index ef7da30..25869ce 100644
--- a/maubot/management/api/instance_database.py
+++ b/maubot/management/api/instance_database.py
@@ -29,8 +29,8 @@ from .responses import resp
@routes.get("/instance/{id}/database")
async def get_database(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "")
- instance = PluginInstance.get(instance_id, None)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
@@ -65,8 +65,8 @@ def check_type(val):
@routes.get("/instance/{id}/database/{table}")
async def get_table(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "")
- instance = PluginInstance.get(instance_id, None)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
@@ -86,14 +86,14 @@ async def get_table(request: web.Request) -> web.Response:
]
except KeyError:
order = []
- limit = int(request.query.get("limit", 100))
+ limit = int(request.query.get("limit", "100"))
return execute_query(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query")
async def query(request: web.Request) -> web.Response:
- instance_id = request.match_info.get("id", "")
- instance = PluginInstance.get(instance_id, None)
+ instance_id = request.match_info["id"].lower()
+ instance = await PluginInstance.get(instance_id)
if not instance:
return resp.instance_not_found
elif not instance.inst_db:
diff --git a/maubot/management/api/log.py b/maubot/management/api/log.py
index 1c5df93..05c11d3 100644
--- a/maubot/management/api/log.py
+++ b/maubot/management/api/log.py
@@ -23,7 +23,7 @@ import logging
from aiohttp import web, web_ws
from .auth import is_valid_token
-from .base import get_loop, routes
+from .base import routes
BUILTIN_ATTRS = {
"args",
@@ -138,12 +138,12 @@ async def log_websocket(request: web.Request) -> web.WebSocketResponse:
authenticated = False
async def close_if_not_authenticated():
- await asyncio.sleep(5, loop=get_loop())
+ await asyncio.sleep(5)
if not authenticated:
await ws.close(code=4000)
log.debug(f"Connection from {request.remote} terminated due to no authentication")
- asyncio.ensure_future(close_if_not_authenticated())
+ asyncio.create_task(close_if_not_authenticated())
try:
msg: web_ws.WSMessage
diff --git a/maubot/management/api/plugin.py b/maubot/management/api/plugin.py
index ecd3c6a..94d8d9d 100644
--- a/maubot/management/api/plugin.py
+++ b/maubot/management/api/plugin.py
@@ -29,8 +29,8 @@ async def get_plugins(_) -> web.Response:
@routes.get("/plugin/{id}")
async def get_plugin(request: web.Request) -> web.Response:
- plugin_id = request.match_info.get("id", None)
- plugin = PluginLoader.id_cache.get(plugin_id, None)
+ plugin_id = request.match_info["id"]
+ plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found
return resp.found(plugin.to_dict())
@@ -38,8 +38,8 @@ async def get_plugin(request: web.Request) -> web.Response:
@routes.delete("/plugin/{id}")
async def delete_plugin(request: web.Request) -> web.Response:
- plugin_id = request.match_info.get("id", None)
- plugin = PluginLoader.id_cache.get(plugin_id, None)
+ plugin_id = request.match_info["id"]
+ plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found
elif len(plugin.references) > 0:
@@ -50,8 +50,8 @@ async def delete_plugin(request: web.Request) -> web.Response:
@routes.post("/plugin/{id}/reload")
async def reload_plugin(request: web.Request) -> web.Response:
- plugin_id = request.match_info.get("id", None)
- plugin = PluginLoader.id_cache.get(plugin_id, None)
+ plugin_id = request.match_info["id"]
+ plugin = PluginLoader.id_cache.get(plugin_id)
if not plugin:
return resp.plugin_not_found
diff --git a/maubot/management/api/plugin_upload.py b/maubot/management/api/plugin_upload.py
index f187c71..ffedbb8 100644
--- a/maubot/management/api/plugin_upload.py
+++ b/maubot/management/api/plugin_upload.py
@@ -29,7 +29,7 @@ from .responses import resp
@routes.put("/plugin/{id}")
async def put_plugin(request: web.Request) -> web.Response:
- plugin_id = request.match_info.get("id", None)
+ plugin_id = request.match_info["id"]
content = await request.read()
file = BytesIO(content)
try:
diff --git a/maubot/plugin_base.py b/maubot/plugin_base.py
index f6bb578..3fae788 100644
--- a/maubot/plugin_base.py
+++ b/maubot/plugin_base.py
@@ -15,7 +15,7 @@
# along with this program. If not, see .
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Awaitable
from abc import ABC
from asyncio import AbstractEventLoop
@@ -124,6 +124,7 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None
- def on_external_config_update(self) -> None:
+ def on_external_config_update(self) -> Awaitable[None] | None:
if self.config:
self.config.load_and_update()
+ return None
diff --git a/optional-requirements.txt b/optional-requirements.txt
index 0397722..f42cab6 100644
--- a/optional-requirements.txt
+++ b/optional-requirements.txt
@@ -1,13 +1,10 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
-#/postgres
-psycopg2-binary>=2,<3
-asyncpg>=0.20,<0.26
+#/sqlite
+aiosqlite>=0.16,<0.18
#/encryption
-asyncpg>=0.20,<0.26
-aiosqlite>=0.16,<0.18
python-olm>=3,<4
pycryptodome>=3,<4
unpaddedbase64>=1,<3
diff --git a/requirements.txt b/requirements.txt
index ee3bc37..dd541e4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
-mautrix>=0.15.0,<0.16
+mautrix>=0.15.2,<0.16
aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4
+asyncpg>=0.20,<0.26
alembic>=1,<2
commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18
diff --git a/setup.py b/setup.py
index 574a1c6..cba8c20 100644
--- a/setup.py
+++ b/setup.py
@@ -1,5 +1,4 @@
import setuptools
-import glob
import os
with open("requirements.txt") as reqs:
@@ -57,9 +56,7 @@ setuptools.setup(
mbc=maubot.cli:app
""",
data_files=[
- (".", ["maubot/example-config.yaml", "alembic.ini"]),
- ("alembic", ["alembic/env.py"]),
- ("alembic/versions", glob.glob("alembic/versions/*.py")),
+ (".", ["maubot/example-config.yaml"]),
],
package_data={
"maubot": [