Update standalone mode to asyncpg/aiosqlite

This commit is contained in:
Tulir Asokan 2022-03-25 21:12:40 +02:00
parent c4f9a3bdf5
commit 29b4a3c892
4 changed files with 80 additions and 57 deletions

View file

@ -29,24 +29,20 @@ from aiohttp import ClientSession, hdrs, web
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap
from yarl import URL
import sqlalchemy as sql
from mautrix.types import (
EventType,
Filter,
FilterID,
Membership,
RoomEventFilter,
RoomFilter,
StrippedStateEvent,
SyncToken,
)
from mautrix.util.async_db import Database
from mautrix.util.config import BaseMissingError, RecursiveDict
from mautrix.util.db import Base
from mautrix.util.logging import TraceLogger
from ..__meta__ import __version__
from ..lib.store_proxy import SyncStoreProxy
from ..loader import PluginMeta
from ..matrix import MaubotMatrixClient
from ..plugin_base import Plugin
@ -60,10 +56,9 @@ crypto_import_error = None
try:
from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
from mautrix.util.async_db import Database as AsyncDatabase
except ImportError as err:
crypto_import_error = err
OlmMachine = AsyncDatabase = PgCryptoStateStore = PgCryptoStore = None
OlmMachine = PgCryptoStateStore = PgCryptoStore = None
parser = argparse.ArgumentParser(
description="A plugin-based Matrix bot system -- standalone mode.",
@ -124,32 +119,26 @@ loader = FileSystemLoader(os.path.dirname(args.meta))
log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
log.debug("Opening database")
db = sql.create_engine(config["database"])
Base.metadata.bind = db
Base.metadata.create_all()
NextBatch.bind(db)
db = Database.create(
config["database"],
db_args=config.get("database_opts", {}),
ignore_foreign_tables=True,
)
user_id = config["user.credentials.id"]
device_id = config["user.credentials.device_id"]
homeserver = config["user.credentials.homeserver"]
access_token = config["user.credentials.access_token"]
crypto_store = crypto_db = state_store = None
crypto_store = state_store = None
if device_id and not OlmMachine:
log.warning(
"device_id set in config, but encryption dependencies not installed",
exc_info=crypto_import_error,
)
elif device_id:
crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table)
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db)
state_store = PgCryptoStateStore(crypto_db)
nb = NextBatch.get(user_id)
if not nb:
nb = NextBatch(user_id=user_id, next_batch=SyncToken(""), filter_id=FilterID(""))
nb.insert()
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db)
state_store = PgCryptoStateStore(db)
bot_config = None
if not meta.config and "base-config.yaml" in meta.extra_files:
@ -188,11 +177,9 @@ if meta.webapp:
async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
if req.path.startswith(web_base_path):
req = req.clone(
rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query(
req.query_string
)
)
path_override = req.rel_url.path[len(web_base_path) :]
url_override = req.rel_url.with_path(path_override).with_query(req.query_string)
req = req.clone(rel_url=url_override)
return await plugin_webapp.handle(req)
return web.Response(status=404)
@ -213,6 +200,9 @@ async def main():
global client, bot
await db.start()
nb = await NextBatch(db, user_id).load()
client_log = logging.getLogger("maubot.client").getChild(user_id)
client = MaubotMatrixClient(
mxid=user_id,
@ -221,15 +211,15 @@ async def main():
client_session=http_client,
loop=loop,
log=client_log,
sync_store=SyncStoreProxy(nb),
sync_store=nb,
state_store=state_store,
device_id=device_id,
)
client.ignore_first_sync = config["user.ignore_first_sync"]
client.ignore_initial_sync = config["user.ignore_initial_sync"]
if crypto_store:
await crypto_db.start()
await state_store.upgrade_table.upgrade(crypto_db)
await crypto_store.upgrade_table.upgrade(db)
await state_store.upgrade_table.upgrade(db)
await crypto_store.open()
client.crypto = OlmMachine(client, crypto_store, state_store)
@ -254,8 +244,11 @@ async def main():
while True:
try:
whoami = await client.whoami()
except Exception:
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
except Exception as e:
log.error(
f"Failed to connect to homeserver: {type(e).__name__}: {e}"
" - retrying in 10 seconds..."
)
await asyncio.sleep(10)
continue
if whoami.user_id != user_id:
@ -272,14 +265,11 @@ async def main():
if config["user.sync"]:
if not nb.filter_id:
nb.edit(
filter_id=await client.create_filter(
Filter(
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
filter_id = await client.create_filter(
Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50)))
)
)
)
client.start(nb.filter_id)
await nb.put_filter_id(filter_id)
_ = client.start(nb.filter_id)
if config["user.autojoin"]:
log.debug("Autojoin is enabled")
@ -321,25 +311,28 @@ async def stop(suppress_stop_error: bool = False) -> None:
except Exception:
if not suppress_stop_error:
log.exception("Error stopping bot")
if crypto_db:
await crypto_db.stop()
if web_runner:
await web_runner.shutdown()
await web_runner.cleanup()
await db.stop()
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try:
log.info("Starting plugin")
loop.run_until_complete(main())
except Exception:
except (Exception, KeyboardInterrupt) as e:
if isinstance(e, KeyboardInterrupt):
log.info("Startup interrupted, stopping")
else:
log.fatal("Failed to start plugin", exc_info=True)
loop.run_until_complete(stop(suppress_stop_error=True))
loop.close()
sys.exit(1)
signal.signal(signal.SIGINT, signal.default_int_handler)
signal.signal(signal.SIGTERM, signal.default_int_handler)
try:
log.info("Startup completed, running forever")
loop.run_forever()

View file

@ -41,8 +41,8 @@ class Config(BaseFileConfig):
copy("server.port")
copy("server.base_path")
copy("server.public_url")
if "database" in base:
copy("database")
copy("database_opts")
if "plugin_config" in base:
copy("plugin_config")
copy("logging")

View file

@ -15,19 +15,41 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
import sqlalchemy as sql
from attr import dataclass
from mautrix.client import SyncStore
from mautrix.types import FilterID, SyncToken, UserID
from mautrix.util.db import Base
from mautrix.util.async_db import Database
find_q = "SELECT next_batch, filter_id FROM standalone_next_batch WHERE user_id=$1"
insert_q = "INSERT INTO standalone_next_batch (user_id, next_batch, filter_id) VALUES ($1, $2, $3)"
update_nb_q = "UPDATE standalone_next_batch SET next_batch=$1 WHERE user_id=$2"
update_filter_q = "UPDATE standalone_next_batch SET filter_id=$1 WHERE user_id=$2"
class NextBatch(Base):
__tablename__ = "standalone_next_batch"
@dataclass
class NextBatch(SyncStore):
db: Database
user_id: UserID
next_batch: SyncToken = ""
filter_id: FilterID = ""
user_id: UserID = sql.Column(sql.String(255), primary_key=True)
next_batch: SyncToken = sql.Column(sql.String(255))
filter_id: FilterID = sql.Column(sql.String(255))
async def load(self) -> NextBatch:
row = await self.db.fetchrow(find_q, self.user_id)
if row is not None:
self.next_batch = row["next_batch"]
self.filter_id = row["filter_id"]
else:
await self.db.execute(insert_q, self.user_id, self.next_batch, self.filter_id)
return self
@classmethod
def get(cls, user_id: UserID) -> NextBatch | None:
return cls._select_one_or_none(cls.c.user_id == user_id)
async def put_filter_id(self, filter_id: FilterID) -> None:
self.filter_id = filter_id
await self.db.execute(update_filter_q, self.filter_id, self.user_id)
async def put_next_batch(self, next_batch: SyncToken) -> None:
self.next_batch = next_batch
await self.db.execute(update_nb_q, self.next_batch, self.user_id)
async def get_next_batch(self) -> SyncToken:
return self.next_batch

View file

@ -37,6 +37,14 @@ server:
# SQLite and Postgres are supported.
database: sqlite:///bot.db
# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
database_opts:
min_size: 1
max_size: 10
# Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here.
plugin_config: {}