Update standalone mode to asyncpg/aiosqlite
This commit is contained in:
parent
c4f9a3bdf5
commit
29b4a3c892
4 changed files with 80 additions and 57 deletions
|
@ -29,24 +29,20 @@ from aiohttp import ClientSession, hdrs, web
|
|||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap
|
||||
from yarl import URL
|
||||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import (
|
||||
EventType,
|
||||
Filter,
|
||||
FilterID,
|
||||
Membership,
|
||||
RoomEventFilter,
|
||||
RoomFilter,
|
||||
StrippedStateEvent,
|
||||
SyncToken,
|
||||
)
|
||||
from mautrix.util.async_db import Database
|
||||
from mautrix.util.config import BaseMissingError, RecursiveDict
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
||||
from ..__meta__ import __version__
|
||||
from ..lib.store_proxy import SyncStoreProxy
|
||||
from ..loader import PluginMeta
|
||||
from ..matrix import MaubotMatrixClient
|
||||
from ..plugin_base import Plugin
|
||||
|
@ -60,10 +56,9 @@ crypto_import_error = None
|
|||
|
||||
try:
|
||||
from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
|
||||
from mautrix.util.async_db import Database as AsyncDatabase
|
||||
except ImportError as err:
|
||||
crypto_import_error = err
|
||||
OlmMachine = AsyncDatabase = PgCryptoStateStore = PgCryptoStore = None
|
||||
OlmMachine = PgCryptoStateStore = PgCryptoStore = None
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A plugin-based Matrix bot system -- standalone mode.",
|
||||
|
@ -124,32 +119,26 @@ loader = FileSystemLoader(os.path.dirname(args.meta))
|
|||
|
||||
log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
|
||||
|
||||
log.debug("Opening database")
|
||||
db = sql.create_engine(config["database"])
|
||||
Base.metadata.bind = db
|
||||
Base.metadata.create_all()
|
||||
NextBatch.bind(db)
|
||||
db = Database.create(
|
||||
config["database"],
|
||||
db_args=config.get("database_opts", {}),
|
||||
ignore_foreign_tables=True,
|
||||
)
|
||||
|
||||
user_id = config["user.credentials.id"]
|
||||
device_id = config["user.credentials.device_id"]
|
||||
homeserver = config["user.credentials.homeserver"]
|
||||
access_token = config["user.credentials.access_token"]
|
||||
|
||||
crypto_store = crypto_db = state_store = None
|
||||
crypto_store = state_store = None
|
||||
if device_id and not OlmMachine:
|
||||
log.warning(
|
||||
"device_id set in config, but encryption dependencies not installed",
|
||||
exc_info=crypto_import_error,
|
||||
)
|
||||
elif device_id:
|
||||
crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table)
|
||||
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db)
|
||||
state_store = PgCryptoStateStore(crypto_db)
|
||||
|
||||
nb = NextBatch.get(user_id)
|
||||
if not nb:
|
||||
nb = NextBatch(user_id=user_id, next_batch=SyncToken(""), filter_id=FilterID(""))
|
||||
nb.insert()
|
||||
crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db)
|
||||
state_store = PgCryptoStateStore(db)
|
||||
|
||||
bot_config = None
|
||||
if not meta.config and "base-config.yaml" in meta.extra_files:
|
||||
|
@ -188,11 +177,9 @@ if meta.webapp:
|
|||
|
||||
async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
|
||||
if req.path.startswith(web_base_path):
|
||||
req = req.clone(
|
||||
rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query(
|
||||
req.query_string
|
||||
)
|
||||
)
|
||||
path_override = req.rel_url.path[len(web_base_path) :]
|
||||
url_override = req.rel_url.with_path(path_override).with_query(req.query_string)
|
||||
req = req.clone(rel_url=url_override)
|
||||
return await plugin_webapp.handle(req)
|
||||
return web.Response(status=404)
|
||||
|
||||
|
@ -213,6 +200,9 @@ async def main():
|
|||
|
||||
global client, bot
|
||||
|
||||
await db.start()
|
||||
nb = await NextBatch(db, user_id).load()
|
||||
|
||||
client_log = logging.getLogger("maubot.client").getChild(user_id)
|
||||
client = MaubotMatrixClient(
|
||||
mxid=user_id,
|
||||
|
@ -221,15 +211,15 @@ async def main():
|
|||
client_session=http_client,
|
||||
loop=loop,
|
||||
log=client_log,
|
||||
sync_store=SyncStoreProxy(nb),
|
||||
sync_store=nb,
|
||||
state_store=state_store,
|
||||
device_id=device_id,
|
||||
)
|
||||
client.ignore_first_sync = config["user.ignore_first_sync"]
|
||||
client.ignore_initial_sync = config["user.ignore_initial_sync"]
|
||||
if crypto_store:
|
||||
await crypto_db.start()
|
||||
await state_store.upgrade_table.upgrade(crypto_db)
|
||||
await crypto_store.upgrade_table.upgrade(db)
|
||||
await state_store.upgrade_table.upgrade(db)
|
||||
await crypto_store.open()
|
||||
|
||||
client.crypto = OlmMachine(client, crypto_store, state_store)
|
||||
|
@ -254,8 +244,11 @@ async def main():
|
|||
while True:
|
||||
try:
|
||||
whoami = await client.whoami()
|
||||
except Exception:
|
||||
log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
|
||||
except Exception as e:
|
||||
log.error(
|
||||
f"Failed to connect to homeserver: {type(e).__name__}: {e}"
|
||||
" - retrying in 10 seconds..."
|
||||
)
|
||||
await asyncio.sleep(10)
|
||||
continue
|
||||
if whoami.user_id != user_id:
|
||||
|
@ -272,14 +265,11 @@ async def main():
|
|||
|
||||
if config["user.sync"]:
|
||||
if not nb.filter_id:
|
||||
nb.edit(
|
||||
filter_id=await client.create_filter(
|
||||
Filter(
|
||||
room=RoomFilter(timeline=RoomEventFilter(limit=50)),
|
||||
)
|
||||
)
|
||||
filter_id = await client.create_filter(
|
||||
Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50)))
|
||||
)
|
||||
client.start(nb.filter_id)
|
||||
await nb.put_filter_id(filter_id)
|
||||
_ = client.start(nb.filter_id)
|
||||
|
||||
if config["user.autojoin"]:
|
||||
log.debug("Autojoin is enabled")
|
||||
|
@ -321,25 +311,28 @@ async def stop(suppress_stop_error: bool = False) -> None:
|
|||
except Exception:
|
||||
if not suppress_stop_error:
|
||||
log.exception("Error stopping bot")
|
||||
if crypto_db:
|
||||
await crypto_db.stop()
|
||||
if web_runner:
|
||||
await web_runner.shutdown()
|
||||
await web_runner.cleanup()
|
||||
await db.stop()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
||||
|
||||
|
||||
try:
|
||||
log.info("Starting plugin")
|
||||
loop.run_until_complete(main())
|
||||
except Exception:
|
||||
log.fatal("Failed to start plugin", exc_info=True)
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
if isinstance(e, KeyboardInterrupt):
|
||||
log.info("Startup interrupted, stopping")
|
||||
else:
|
||||
log.fatal("Failed to start plugin", exc_info=True)
|
||||
loop.run_until_complete(stop(suppress_stop_error=True))
|
||||
loop.close()
|
||||
sys.exit(1)
|
||||
|
||||
signal.signal(signal.SIGINT, signal.default_int_handler)
|
||||
signal.signal(signal.SIGTERM, signal.default_int_handler)
|
||||
|
||||
try:
|
||||
log.info("Startup completed, running forever")
|
||||
loop.run_forever()
|
||||
|
|
|
@ -41,8 +41,8 @@ class Config(BaseFileConfig):
|
|||
copy("server.port")
|
||||
copy("server.base_path")
|
||||
copy("server.public_url")
|
||||
if "database" in base:
|
||||
copy("database")
|
||||
copy("database")
|
||||
copy("database_opts")
|
||||
if "plugin_config" in base:
|
||||
copy("plugin_config")
|
||||
copy("logging")
|
||||
|
|
|
@ -15,19 +15,41 @@
|
|||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sql
|
||||
from attr import dataclass
|
||||
|
||||
from mautrix.client import SyncStore
|
||||
from mautrix.types import FilterID, SyncToken, UserID
|
||||
from mautrix.util.db import Base
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
find_q = "SELECT next_batch, filter_id FROM standalone_next_batch WHERE user_id=$1"
|
||||
insert_q = "INSERT INTO standalone_next_batch (user_id, next_batch, filter_id) VALUES ($1, $2, $3)"
|
||||
update_nb_q = "UPDATE standalone_next_batch SET next_batch=$1 WHERE user_id=$2"
|
||||
update_filter_q = "UPDATE standalone_next_batch SET filter_id=$1 WHERE user_id=$2"
|
||||
|
||||
|
||||
class NextBatch(Base):
|
||||
__tablename__ = "standalone_next_batch"
|
||||
@dataclass
|
||||
class NextBatch(SyncStore):
|
||||
db: Database
|
||||
user_id: UserID
|
||||
next_batch: SyncToken = ""
|
||||
filter_id: FilterID = ""
|
||||
|
||||
user_id: UserID = sql.Column(sql.String(255), primary_key=True)
|
||||
next_batch: SyncToken = sql.Column(sql.String(255))
|
||||
filter_id: FilterID = sql.Column(sql.String(255))
|
||||
async def load(self) -> NextBatch:
|
||||
row = await self.db.fetchrow(find_q, self.user_id)
|
||||
if row is not None:
|
||||
self.next_batch = row["next_batch"]
|
||||
self.filter_id = row["filter_id"]
|
||||
else:
|
||||
await self.db.execute(insert_q, self.user_id, self.next_batch, self.filter_id)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def get(cls, user_id: UserID) -> NextBatch | None:
|
||||
return cls._select_one_or_none(cls.c.user_id == user_id)
|
||||
async def put_filter_id(self, filter_id: FilterID) -> None:
|
||||
self.filter_id = filter_id
|
||||
await self.db.execute(update_filter_q, self.filter_id, self.user_id)
|
||||
|
||||
async def put_next_batch(self, next_batch: SyncToken) -> None:
|
||||
self.next_batch = next_batch
|
||||
await self.db.execute(update_nb_q, self.next_batch, self.user_id)
|
||||
|
||||
async def get_next_batch(self) -> SyncToken:
|
||||
return self.next_batch
|
||||
|
|
|
@ -37,6 +37,14 @@ server:
|
|||
# SQLite and Postgres are supported.
|
||||
database: sqlite:///bot.db
|
||||
|
||||
# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
|
||||
# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
|
||||
# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
|
||||
# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
|
||||
database_opts:
|
||||
min_size: 1
|
||||
max_size: 10
|
||||
|
||||
# Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here.
|
||||
plugin_config: {}
|
||||
|
||||
|
|
Loading…
Reference in a new issue