diff --git a/maubot/standalone/__main__.py b/maubot/standalone/__main__.py index 81bba1e..97ac271 100644 --- a/maubot/standalone/__main__.py +++ b/maubot/standalone/__main__.py @@ -29,24 +29,20 @@ from aiohttp import ClientSession, hdrs, web from ruamel.yaml import YAML from ruamel.yaml.comments import CommentedMap from yarl import URL -import sqlalchemy as sql from mautrix.types import ( EventType, Filter, - FilterID, Membership, RoomEventFilter, RoomFilter, StrippedStateEvent, - SyncToken, ) +from mautrix.util.async_db import Database from mautrix.util.config import BaseMissingError, RecursiveDict -from mautrix.util.db import Base from mautrix.util.logging import TraceLogger from ..__meta__ import __version__ -from ..lib.store_proxy import SyncStoreProxy from ..loader import PluginMeta from ..matrix import MaubotMatrixClient from ..plugin_base import Plugin @@ -60,10 +56,9 @@ crypto_import_error = None try: from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore - from mautrix.util.async_db import Database as AsyncDatabase except ImportError as err: crypto_import_error = err - OlmMachine = AsyncDatabase = PgCryptoStateStore = PgCryptoStore = None + OlmMachine = PgCryptoStateStore = PgCryptoStore = None parser = argparse.ArgumentParser( description="A plugin-based Matrix bot system -- standalone mode.", @@ -124,32 +119,26 @@ loader = FileSystemLoader(os.path.dirname(args.meta)) log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}") -log.debug("Opening database") -db = sql.create_engine(config["database"]) -Base.metadata.bind = db -Base.metadata.create_all() -NextBatch.bind(db) +db = Database.create( + config["database"], + db_args=config.get("database_opts", {}), + ignore_foreign_tables=True, +) user_id = config["user.credentials.id"] device_id = config["user.credentials.device_id"] homeserver = config["user.credentials.homeserver"] access_token = config["user.credentials.access_token"] -crypto_store = crypto_db = state_store = None +crypto_store = state_store = None if device_id and not OlmMachine: log.warning( "device_id set in config, but encryption dependencies not installed", exc_info=crypto_import_error, ) elif device_id: - crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table) - crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db) - state_store = PgCryptoStateStore(crypto_db) - -nb = NextBatch.get(user_id) -if not nb: - nb = NextBatch(user_id=user_id, next_batch=SyncToken(""), filter_id=FilterID("")) - nb.insert() + crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db) + state_store = PgCryptoStateStore(db) bot_config = None if not meta.config and "base-config.yaml" in meta.extra_files: @@ -188,11 +177,9 @@ if meta.webapp: async def _handle_plugin_request(req: web.Request) -> web.StreamResponse: if req.path.startswith(web_base_path): - req = req.clone( - rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query( - req.query_string - ) - ) + path_override = req.rel_url.path[len(web_base_path) :] + url_override = req.rel_url.with_path(path_override).with_query(req.query_string) + req = req.clone(rel_url=url_override) return await plugin_webapp.handle(req) return web.Response(status=404) @@ -213,6 +200,9 @@ async def main(): global client, bot + await db.start() + nb = await NextBatch(db, user_id).load() + client_log = logging.getLogger("maubot.client").getChild(user_id) client = MaubotMatrixClient( mxid=user_id, @@ -221,15 +211,15 @@ async def main(): client_session=http_client, loop=loop, log=client_log, - sync_store=SyncStoreProxy(nb), + sync_store=nb, state_store=state_store, device_id=device_id, ) client.ignore_first_sync = config["user.ignore_first_sync"] client.ignore_initial_sync = config["user.ignore_initial_sync"] if crypto_store: - await crypto_db.start() - await state_store.upgrade_table.upgrade(crypto_db) + await crypto_store.upgrade_table.upgrade(db) + await state_store.upgrade_table.upgrade(db) await crypto_store.open() client.crypto = OlmMachine(client, crypto_store, state_store) @@ -254,8 +244,11 @@ async def main(): while True: try: whoami = await client.whoami() - except Exception: - log.exception("Failed to connect to homeserver, retrying in 10 seconds...") + except Exception as e: + log.error( + f"Failed to connect to homeserver: {type(e).__name__}: {e}" + " - retrying in 10 seconds..." + ) await asyncio.sleep(10) continue if whoami.user_id != user_id: @@ -272,14 +265,11 @@ async def main(): if config["user.sync"]: if not nb.filter_id: - nb.edit( - filter_id=await client.create_filter( - Filter( - room=RoomFilter(timeline=RoomEventFilter(limit=50)), - ) - ) + filter_id = await client.create_filter( + Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50))) ) - client.start(nb.filter_id) + await nb.put_filter_id(filter_id) + _ = client.start(nb.filter_id) if config["user.autojoin"]: log.debug("Autojoin is enabled") @@ -321,25 +311,28 @@ async def stop(suppress_stop_error: bool = False) -> None: except Exception: if not suppress_stop_error: log.exception("Error stopping bot") - if crypto_db: - await crypto_db.stop() if web_runner: await web_runner.shutdown() await web_runner.cleanup() + await db.stop() + + +signal.signal(signal.SIGINT, signal.default_int_handler) +signal.signal(signal.SIGTERM, signal.default_int_handler) try: log.info("Starting plugin") loop.run_until_complete(main()) -except Exception: - log.fatal("Failed to start plugin", exc_info=True) +except (Exception, KeyboardInterrupt) as e: + if isinstance(e, KeyboardInterrupt): + log.info("Startup interrupted, stopping") + else: + log.fatal("Failed to start plugin", exc_info=True) loop.run_until_complete(stop(suppress_stop_error=True)) loop.close() sys.exit(1) -signal.signal(signal.SIGINT, signal.default_int_handler) -signal.signal(signal.SIGTERM, signal.default_int_handler) - try: log.info("Startup completed, running forever") loop.run_forever() diff --git a/maubot/standalone/config.py b/maubot/standalone/config.py index 7c977ff..ce17310 100644 --- a/maubot/standalone/config.py +++ b/maubot/standalone/config.py @@ -41,8 +41,8 @@ class Config(BaseFileConfig): copy("server.port") copy("server.base_path") copy("server.public_url") - if "database" in base: - copy("database") + copy("database") + copy("database_opts") if "plugin_config" in base: copy("plugin_config") copy("logging") diff --git a/maubot/standalone/database.py b/maubot/standalone/database.py index cdc3525..e69a220 100644 --- a/maubot/standalone/database.py +++ b/maubot/standalone/database.py @@ -15,19 +15,41 @@ # along with this program. If not, see . from __future__ import annotations -import sqlalchemy as sql +from attr import dataclass +from mautrix.client import SyncStore from mautrix.types import FilterID, SyncToken, UserID -from mautrix.util.db import Base +from mautrix.util.async_db import Database + +find_q = "SELECT next_batch, filter_id FROM standalone_next_batch WHERE user_id=$1" +insert_q = "INSERT INTO standalone_next_batch (user_id, next_batch, filter_id) VALUES ($1, $2, $3)" +update_nb_q = "UPDATE standalone_next_batch SET next_batch=$1 WHERE user_id=$2" +update_filter_q = "UPDATE standalone_next_batch SET filter_id=$1 WHERE user_id=$2" -class NextBatch(Base): - __tablename__ = "standalone_next_batch" +@dataclass +class NextBatch(SyncStore): + db: Database + user_id: UserID + next_batch: SyncToken = "" + filter_id: FilterID = "" - user_id: UserID = sql.Column(sql.String(255), primary_key=True) - next_batch: SyncToken = sql.Column(sql.String(255)) - filter_id: FilterID = sql.Column(sql.String(255)) + async def load(self) -> NextBatch: + row = await self.db.fetchrow(find_q, self.user_id) + if row is not None: + self.next_batch = row["next_batch"] + self.filter_id = row["filter_id"] + else: + await self.db.execute(insert_q, self.user_id, self.next_batch, self.filter_id) + return self - @classmethod - def get(cls, user_id: UserID) -> NextBatch | None: - return cls._select_one_or_none(cls.c.user_id == user_id) + async def put_filter_id(self, filter_id: FilterID) -> None: + self.filter_id = filter_id + await self.db.execute(update_filter_q, self.filter_id, self.user_id) + + async def put_next_batch(self, next_batch: SyncToken) -> None: + self.next_batch = next_batch + await self.db.execute(update_nb_q, self.next_batch, self.user_id) + + async def get_next_batch(self) -> SyncToken: + return self.next_batch diff --git a/maubot/standalone/example-config.yaml b/maubot/standalone/example-config.yaml index ffdf699..1884b78 100644 --- a/maubot/standalone/example-config.yaml +++ b/maubot/standalone/example-config.yaml @@ -37,6 +37,14 @@ server: # SQLite and Postgres are supported. database: sqlite:///bot.db +# Additional arguments for asyncpg.create_pool() or sqlite3.connect() +# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool +# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect +# For sqlite, min_size is used as the connection thread pool size and max_size is ignored. +database_opts: + min_size: 1 + max_size: 10 + # Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here. plugin_config: {}