Update standalone mode to asyncpg/aiosqlite

This commit is contained in:
Tulir Asokan 2022-03-25 21:12:40 +02:00
parent c4f9a3bdf5
commit 29b4a3c892
4 changed files with 80 additions and 57 deletions

View file

@ -29,24 +29,20 @@ from aiohttp import ClientSession, hdrs, web
from ruamel.yaml import YAML from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap from ruamel.yaml.comments import CommentedMap
from yarl import URL from yarl import URL
import sqlalchemy as sql
from mautrix.types import ( from mautrix.types import (
EventType, EventType,
Filter, Filter,
FilterID,
Membership, Membership,
RoomEventFilter, RoomEventFilter,
RoomFilter, RoomFilter,
StrippedStateEvent, StrippedStateEvent,
SyncToken,
) )
from mautrix.util.async_db import Database
from mautrix.util.config import BaseMissingError, RecursiveDict from mautrix.util.config import BaseMissingError, RecursiveDict
from mautrix.util.db import Base
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
from ..__meta__ import __version__ from ..__meta__ import __version__
from ..lib.store_proxy import SyncStoreProxy
from ..loader import PluginMeta from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient from ..matrix import MaubotMatrixClient
from ..plugin_base import Plugin from ..plugin_base import Plugin
@ -60,10 +56,9 @@ crypto_import_error = None
try: try:
from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
from mautrix.util.async_db import Database as AsyncDatabase
except ImportError as err: except ImportError as err:
crypto_import_error = err crypto_import_error = err
OlmMachine = AsyncDatabase = PgCryptoStateStore = PgCryptoStore = None OlmMachine = PgCryptoStateStore = PgCryptoStore = None
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="A plugin-based Matrix bot system -- standalone mode.", description="A plugin-based Matrix bot system -- standalone mode.",
@ -124,32 +119,26 @@ loader = FileSystemLoader(os.path.dirname(args.meta))
log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}") log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
log.debug("Opening database") db = Database.create(
db = sql.create_engine(config["database"]) config["database"],
Base.metadata.bind = db db_args=config.get("database_opts", {}),
Base.metadata.create_all() ignore_foreign_tables=True,
NextBatch.bind(db) )
user_id = config["user.credentials.id"] user_id = config["user.credentials.id"]
device_id = config["user.credentials.device_id"] device_id = config["user.credentials.device_id"]
homeserver = config["user.credentials.homeserver"] homeserver = config["user.credentials.homeserver"]
access_token = config["user.credentials.access_token"] access_token = config["user.credentials.access_token"]
crypto_store = crypto_db = state_store = None crypto_store = state_store = None
if device_id and not OlmMachine: if device_id and not OlmMachine:
log.warning( log.warning(
"device_id set in config, but encryption dependencies not installed", "device_id set in config, but encryption dependencies not installed",
exc_info=crypto_import_error, exc_info=crypto_import_error,
) )
elif device_id: elif device_id:
crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table) crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db)
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db) state_store = PgCryptoStateStore(db)
state_store = PgCryptoStateStore(crypto_db)
nb = NextBatch.get(user_id)
if not nb:
nb = NextBatch(user_id=user_id, next_batch=SyncToken(""), filter_id=FilterID(""))
nb.insert()
bot_config = None bot_config = None
if not meta.config and "base-config.yaml" in meta.extra_files: if not meta.config and "base-config.yaml" in meta.extra_files:
@ -188,11 +177,9 @@ if meta.webapp:
async def _handle_plugin_request(req: web.Request) -> web.StreamResponse: async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
if req.path.startswith(web_base_path): if req.path.startswith(web_base_path):
req = req.clone( path_override = req.rel_url.path[len(web_base_path) :]
rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query( url_override = req.rel_url.with_path(path_override).with_query(req.query_string)
req.query_string req = req.clone(rel_url=url_override)
)
)
return await plugin_webapp.handle(req) return await plugin_webapp.handle(req)
return web.Response(status=404) return web.Response(status=404)
@ -213,6 +200,9 @@ async def main():
global client, bot global client, bot
await db.start()
nb = await NextBatch(db, user_id).load()
client_log = logging.getLogger("maubot.client").getChild(user_id) client_log = logging.getLogger("maubot.client").getChild(user_id)
client = MaubotMatrixClient( client = MaubotMatrixClient(
mxid=user_id, mxid=user_id,
@ -221,15 +211,15 @@ async def main():
client_session=http_client, client_session=http_client,
loop=loop, loop=loop,
log=client_log, log=client_log,
sync_store=SyncStoreProxy(nb), sync_store=nb,
state_store=state_store, state_store=state_store,
device_id=device_id, device_id=device_id,
) )
client.ignore_first_sync = config["user.ignore_first_sync"] client.ignore_first_sync = config["user.ignore_first_sync"]
client.ignore_initial_sync = config["user.ignore_initial_sync"] client.ignore_initial_sync = config["user.ignore_initial_sync"]
if crypto_store: if crypto_store:
await crypto_db.start() await crypto_store.upgrade_table.upgrade(db)
await state_store.upgrade_table.upgrade(crypto_db) await state_store.upgrade_table.upgrade(db)
await crypto_store.open() await crypto_store.open()
client.crypto = OlmMachine(client, crypto_store, state_store) client.crypto = OlmMachine(client, crypto_store, state_store)
@ -254,8 +244,11 @@ async def main():
while True: while True:
try: try:
whoami = await client.whoami() whoami = await client.whoami()
except Exception: except Exception as e:
log.exception("Failed to connect to homeserver, retrying in 10 seconds...") log.error(
f"Failed to connect to homeserver: {type(e).__name__}: {e}"
" - retrying in 10 seconds..."
)
await asyncio.sleep(10) await asyncio.sleep(10)
continue continue
if whoami.user_id != user_id: if whoami.user_id != user_id:
@ -272,14 +265,11 @@ async def main():
if config["user.sync"]: if config["user.sync"]:
if not nb.filter_id: if not nb.filter_id:
nb.edit(
filter_id = await client.create_filter( filter_id = await client.create_filter(
Filter( Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50)))
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
) )
) await nb.put_filter_id(filter_id)
) _ = client.start(nb.filter_id)
client.start(nb.filter_id)
if config["user.autojoin"]: if config["user.autojoin"]:
log.debug("Autojoin is enabled") log.debug("Autojoin is enabled")
@ -321,25 +311,28 @@ async def stop(suppress_stop_error: bool = False) -> None:
except Exception: except Exception:
if not suppress_stop_error: if not suppress_stop_error:
log.exception("Error stopping bot") log.exception("Error stopping bot")
if crypto_db:
await crypto_db.stop()
if web_runner: if web_runner:
await web_runner.shutdown() await web_runner.shutdown()
await web_runner.cleanup() await web_runner.cleanup()
await db.stop()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try: try:
log.info("Starting plugin") log.info("Starting plugin")
loop.run_until_complete(main()) loop.run_until_complete(main())
except Exception: except (Exception, KeyboardInterrupt) as e:
if isinstance(e, KeyboardInterrupt):
log.info("Startup interrupted, stopping")
else:
log.fatal("Failed to start plugin", exc_info=True) log.fatal("Failed to start plugin", exc_info=True)
loop.run_until_complete(stop(suppress_stop_error=True)) loop.run_until_complete(stop(suppress_stop_error=True))
loop.close() loop.close()
sys.exit(1) sys.exit(1)
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try: try:
log.info("Startup completed, running forever") log.info("Startup completed, running forever")
loop.run_forever() loop.run_forever()

View file

@ -41,8 +41,8 @@ class Config(BaseFileConfig):
copy("server.port") copy("server.port")
copy("server.base_path") copy("server.base_path")
copy("server.public_url") copy("server.public_url")
if "database" in base:
copy("database") copy("database")
copy("database_opts")
if "plugin_config" in base: if "plugin_config" in base:
copy("plugin_config") copy("plugin_config")
copy("logging") copy("logging")

View file

@ -15,19 +15,41 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
import sqlalchemy as sql from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import FilterID, SyncToken, UserID from mautrix.types import FilterID, SyncToken, UserID
from mautrix.util.db import Base from mautrix.util.async_db import Database
find_q = "SELECT next_batch, filter_id FROM standalone_next_batch WHERE user_id=$1"
insert_q = "INSERT INTO standalone_next_batch (user_id, next_batch, filter_id) VALUES ($1, $2, $3)"
update_nb_q = "UPDATE standalone_next_batch SET next_batch=$1 WHERE user_id=$2"
update_filter_q = "UPDATE standalone_next_batch SET filter_id=$1 WHERE user_id=$2"
class NextBatch(Base): @dataclass
__tablename__ = "standalone_next_batch" class NextBatch(SyncStore):
db: Database
user_id: UserID
next_batch: SyncToken = ""
filter_id: FilterID = ""
user_id: UserID = sql.Column(sql.String(255), primary_key=True) async def load(self) -> NextBatch:
next_batch: SyncToken = sql.Column(sql.String(255)) row = await self.db.fetchrow(find_q, self.user_id)
filter_id: FilterID = sql.Column(sql.String(255)) if row is not None:
self.next_batch = row["next_batch"]
self.filter_id = row["filter_id"]
else:
await self.db.execute(insert_q, self.user_id, self.next_batch, self.filter_id)
return self
@classmethod async def put_filter_id(self, filter_id: FilterID) -> None:
def get(cls, user_id: UserID) -> NextBatch | None: self.filter_id = filter_id
return cls._select_one_or_none(cls.c.user_id == user_id) await self.db.execute(update_filter_q, self.filter_id, self.user_id)
async def put_next_batch(self, next_batch: SyncToken) -> None:
self.next_batch = next_batch
await self.db.execute(update_nb_q, self.next_batch, self.user_id)
async def get_next_batch(self) -> SyncToken:
return self.next_batch

View file

@ -37,6 +37,14 @@ server:
# SQLite and Postgres are supported. # SQLite and Postgres are supported.
database: sqlite:///bot.db database: sqlite:///bot.db
# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
database_opts:
min_size: 1
max_size: 10
# Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here. # Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here.
plugin_config: {} plugin_config: {}