Update standalone mode to asyncpg/aiosqlite

parent c4f9a3bdf5
commit 29b4a3c892
4 changed files with 80 additions and 57 deletions
@@ -29,24 +29,20 @@ from aiohttp import ClientSession, hdrs, web
 from ruamel.yaml import YAML
 from ruamel.yaml.comments import CommentedMap
 from yarl import URL
-import sqlalchemy as sql
 
 from mautrix.types import (
     EventType,
     Filter,
-    FilterID,
     Membership,
     RoomEventFilter,
     RoomFilter,
     StrippedStateEvent,
-    SyncToken,
 )
+from mautrix.util.async_db import Database
 from mautrix.util.config import BaseMissingError, RecursiveDict
-from mautrix.util.db import Base
 from mautrix.util.logging import TraceLogger
 
 from ..__meta__ import __version__
-from ..lib.store_proxy import SyncStoreProxy
 from ..loader import PluginMeta
 from ..matrix import MaubotMatrixClient
 from ..plugin_base import Plugin
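
The import changes above swap the synchronous SQLAlchemy and mautrix.util.db.Base stack for mautrix.util.async_db.Database, which drives Postgres through asyncpg and SQLite through aiosqlite. A minimal sketch of the new lifecycle, using only Database calls that also appear elsewhere in this diff (the demo_kv table and its values are invented for illustration):

    import asyncio

    from mautrix.util.async_db import Database


    async def demo() -> None:
        # The URL scheme selects the backend: sqlite -> aiosqlite, postgres -> asyncpg.
        db = Database.create("sqlite:///bot.db")
        await db.start()  # open the pool/connection
        await db.execute("CREATE TABLE IF NOT EXISTS demo_kv (k TEXT PRIMARY KEY, v TEXT)")
        await db.execute("INSERT INTO demo_kv (k, v) VALUES ($1, $2)", "hello", "world")
        row = await db.fetchrow("SELECT v FROM demo_kv WHERE k=$1", "hello")
        print(row["v"])
        await db.stop()


    asyncio.run(demo())
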
@@ -60,10 +56,9 @@ crypto_import_error = None
 
 try:
     from mautrix.crypto import OlmMachine, PgCryptoStateStore, PgCryptoStore
-    from mautrix.util.async_db import Database as AsyncDatabase
 except ImportError as err:
     crypto_import_error = err
-    OlmMachine = AsyncDatabase = PgCryptoStateStore = PgCryptoStore = None
+    OlmMachine = PgCryptoStateStore = PgCryptoStore = None
 
 parser = argparse.ArgumentParser(
     description="A plugin-based Matrix bot system -- standalone mode.",
@@ -124,32 +119,26 @@ loader = FileSystemLoader(os.path.dirname(args.meta))
 
 log.info(f"Initializing standalone {meta.id} v{meta.version} on maubot {__version__}")
 
-log.debug("Opening database")
-db = sql.create_engine(config["database"])
-Base.metadata.bind = db
-Base.metadata.create_all()
-NextBatch.bind(db)
+db = Database.create(
+    config["database"],
+    db_args=config.get("database_opts", {}),
+    ignore_foreign_tables=True,
+)
 
 user_id = config["user.credentials.id"]
 device_id = config["user.credentials.device_id"]
 homeserver = config["user.credentials.homeserver"]
 access_token = config["user.credentials.access_token"]
 
-crypto_store = crypto_db = state_store = None
+crypto_store = state_store = None
 if device_id and not OlmMachine:
     log.warning(
         "device_id set in config, but encryption dependencies not installed",
         exc_info=crypto_import_error,
     )
 elif device_id:
-    crypto_db = AsyncDatabase.create(config["database"], upgrade_table=PgCryptoStore.upgrade_table)
-    crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=crypto_db)
-    state_store = PgCryptoStateStore(crypto_db)
-
-nb = NextBatch.get(user_id)
-if not nb:
-    nb = NextBatch(user_id=user_id, next_batch=SyncToken(""), filter_id=FilterID(""))
-    nb.insert()
+    crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db)
+    state_store = PgCryptoStateStore(db)
 
 bot_config = None
 if not meta.config and "base-config.yaml" in meta.extra_files:
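
With this hunk the standalone runner builds one async Database from the config and hands the same handle to the crypto stores, rather than opening a separate AsyncDatabase just for encryption; the NextBatch row is also no longer loaded at import time. A condensed sketch of the wiring, with placeholder values standing in for config["database"] and config["database_opts"] (the keyword arguments simply mirror the diff):

    from mautrix.util.async_db import Database

    database_url = "sqlite:///bot.db"                # stand-in for config["database"]
    database_opts = {"min_size": 1, "max_size": 10}  # stand-in for config["database_opts"]

    db = Database.create(
        database_url,
        db_args=database_opts,       # forwarded to asyncpg.create_pool() / sqlite3.connect()
        ignore_foreign_tables=True,  # value copied from the diff above
    )

    # With the encryption extras installed, the crypto stores reuse the same handle:
    # crypto_store = PgCryptoStore(account_id=user_id, pickle_key="mau.crypto", db=db)
    # state_store = PgCryptoStateStore(db)
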
@@ -188,11 +177,9 @@ if meta.webapp:
 
 async def _handle_plugin_request(req: web.Request) -> web.StreamResponse:
     if req.path.startswith(web_base_path):
-        req = req.clone(
-            rel_url=req.rel_url.with_path(req.rel_url.path[len(web_base_path) :]).with_query(
-                req.query_string
-            )
-        )
+        path_override = req.rel_url.path[len(web_base_path) :]
+        url_override = req.rel_url.with_path(path_override).with_query(req.query_string)
+        req = req.clone(rel_url=url_override)
         return await plugin_webapp.handle(req)
     return web.Response(status=404)
 
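
This hunk only unrolls the nested URL-rewriting expression into named steps; the behaviour should be identical. For illustration, the same yarl manipulation in isolation (the base path and request URL below are invented examples, not values from maubot):

    from yarl import URL

    web_base_path = "/plugin/base"
    rel_url = URL("/plugin/base/api/ping?echo=1")

    path_override = rel_url.path[len(web_base_path):]
    url_override = rel_url.with_path(path_override).with_query(rel_url.query_string)
    print(url_override)  # -> /api/ping?echo=1
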
@@ -213,6 +200,9 @@ async def main():
 
     global client, bot
 
+    await db.start()
+    nb = await NextBatch(db, user_id).load()
+
     client_log = logging.getLogger("maubot.client").getChild(user_id)
     client = MaubotMatrixClient(
         mxid=user_id,
@@ -221,15 +211,15 @@ async def main():
         client_session=http_client,
         loop=loop,
         log=client_log,
-        sync_store=SyncStoreProxy(nb),
+        sync_store=nb,
         state_store=state_store,
         device_id=device_id,
     )
     client.ignore_first_sync = config["user.ignore_first_sync"]
     client.ignore_initial_sync = config["user.ignore_initial_sync"]
     if crypto_store:
-        await crypto_db.start()
-        await state_store.upgrade_table.upgrade(crypto_db)
+        await crypto_store.upgrade_table.upgrade(db)
+        await state_store.upgrade_table.upgrade(db)
         await crypto_store.open()
 
         client.crypto = OlmMachine(client, crypto_store, state_store)
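
Two related changes here: NextBatch now implements mautrix's SyncStore interface itself, so the SyncStoreProxy wrapper is gone and the instance is passed straight to the client, and the crypto and state-store schemas are upgraded on the shared database instead of starting a second pool. As a rough sketch of the sync-store contract this relies on (an invented in-memory variant, not maubot code):

    from mautrix.client import SyncStore
    from mautrix.types import SyncToken


    class EphemeralSyncStore(SyncStore):
        """A minimal in-memory sync store: keeps the latest sync token in RAM."""

        def __init__(self) -> None:
            self._next_batch = SyncToken("")

        async def put_next_batch(self, next_batch: SyncToken) -> None:
            self._next_batch = next_batch

        async def get_next_batch(self) -> SyncToken:
            return self._next_batch


    # client = MaubotMatrixClient(..., sync_store=EphemeralSyncStore(), ...)
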
@@ -254,8 +244,11 @@ async def main():
     while True:
         try:
             whoami = await client.whoami()
-        except Exception:
-            log.exception("Failed to connect to homeserver, retrying in 10 seconds...")
+        except Exception as e:
+            log.error(
+                f"Failed to connect to homeserver: {type(e).__name__}: {e}"
+                " - retrying in 10 seconds..."
+            )
             await asyncio.sleep(10)
             continue
         if whoami.user_id != user_id:
@@ -272,14 +265,11 @@ async def main():
 
     if config["user.sync"]:
         if not nb.filter_id:
-            nb.edit(
-                filter_id=await client.create_filter(
-                    Filter(
-                        room=RoomFilter(timeline=RoomEventFilter(limit=50)),
-                    )
-                )
+            filter_id = await client.create_filter(
+                Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50)))
             )
-        client.start(nb.filter_id)
+            await nb.put_filter_id(filter_id)
+        _ = client.start(nb.filter_id)
 
     if config["user.autojoin"]:
         log.debug("Autojoin is enabled")
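
The sync filter itself is unchanged; it is just created once and persisted through the new put_filter_id() coroutine instead of the old row-edit helper, and the assignment to _ merely makes the discarded return value of client.start() explicit. Spelled out, the filter built above is:

    from mautrix.types import Filter, RoomEventFilter, RoomFilter

    # Limit each room's timeline to 50 events per sync, exactly as in the diff.
    sync_filter = Filter(room=RoomFilter(timeline=RoomEventFilter(limit=50)))

    # filter_id = await client.create_filter(sync_filter)
    # await nb.put_filter_id(filter_id)
    # client.start(nb.filter_id)
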
@@ -321,25 +311,28 @@ async def stop(suppress_stop_error: bool = False) -> None:
     except Exception:
         if not suppress_stop_error:
             log.exception("Error stopping bot")
-    if crypto_db:
-        await crypto_db.stop()
     if web_runner:
         await web_runner.shutdown()
         await web_runner.cleanup()
+    await db.stop()
 
 
+signal.signal(signal.SIGINT, signal.default_int_handler)
+signal.signal(signal.SIGTERM, signal.default_int_handler)
+
 try:
     log.info("Starting plugin")
     loop.run_until_complete(main())
-except Exception:
-    log.fatal("Failed to start plugin", exc_info=True)
+except (Exception, KeyboardInterrupt) as e:
+    if isinstance(e, KeyboardInterrupt):
+        log.info("Startup interrupted, stopping")
+    else:
+        log.fatal("Failed to start plugin", exc_info=True)
     loop.run_until_complete(stop(suppress_stop_error=True))
     loop.close()
     sys.exit(1)
 
-signal.signal(signal.SIGINT, signal.default_int_handler)
-signal.signal(signal.SIGTERM, signal.default_int_handler)
-
 try:
     log.info("Startup completed, running forever")
     loop.run_forever()
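
The teardown and startup flow changes in three small ways: stop() now closes the single shared pool with await db.stop(), the SIGINT/SIGTERM handlers are installed before the first startup attempt, and an interrupt during startup is logged as a normal shutdown rather than a fatal error. Moving the signal handlers presumably means a Ctrl+C inside the startup run_until_complete() raises KeyboardInterrupt where the new except clause can catch it, roughly:

    import asyncio
    import signal

    # With the default handler installed, SIGINT raises KeyboardInterrupt inside
    # run_until_complete(), so startup can be aborted cleanly.
    signal.signal(signal.SIGINT, signal.default_int_handler)

    loop = asyncio.new_event_loop()
    try:
        loop.run_until_complete(asyncio.sleep(3600))  # stand-in for main()
    except KeyboardInterrupt:
        print("Startup interrupted, stopping")

The hunks below come from the other three changed files: the standalone Config class, the NextBatch sync store, and the example config.
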
@@ -41,8 +41,8 @@ class Config(BaseFileConfig):
         copy("server.port")
         copy("server.base_path")
         copy("server.public_url")
-        if "database" in base:
-            copy("database")
+        copy("database")
+        copy("database_opts")
         if "plugin_config" in base:
             copy("plugin_config")
         copy("logging")
@@ -15,19 +15,41 @@
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 from __future__ import annotations
 
-import sqlalchemy as sql
+from attr import dataclass
 
+from mautrix.client import SyncStore
 from mautrix.types import FilterID, SyncToken, UserID
-from mautrix.util.db import Base
+from mautrix.util.async_db import Database
 
+find_q = "SELECT next_batch, filter_id FROM standalone_next_batch WHERE user_id=$1"
+insert_q = "INSERT INTO standalone_next_batch (user_id, next_batch, filter_id) VALUES ($1, $2, $3)"
+update_nb_q = "UPDATE standalone_next_batch SET next_batch=$1 WHERE user_id=$2"
+update_filter_q = "UPDATE standalone_next_batch SET filter_id=$1 WHERE user_id=$2"
+
 
-class NextBatch(Base):
-    __tablename__ = "standalone_next_batch"
+@dataclass
+class NextBatch(SyncStore):
+    db: Database
+    user_id: UserID
+    next_batch: SyncToken = ""
+    filter_id: FilterID = ""
 
-    user_id: UserID = sql.Column(sql.String(255), primary_key=True)
-    next_batch: SyncToken = sql.Column(sql.String(255))
-    filter_id: FilterID = sql.Column(sql.String(255))
+    async def load(self) -> NextBatch:
+        row = await self.db.fetchrow(find_q, self.user_id)
+        if row is not None:
+            self.next_batch = row["next_batch"]
+            self.filter_id = row["filter_id"]
+        else:
+            await self.db.execute(insert_q, self.user_id, self.next_batch, self.filter_id)
+        return self
 
-    @classmethod
-    def get(cls, user_id: UserID) -> NextBatch | None:
-        return cls._select_one_or_none(cls.c.user_id == user_id)
+    async def put_filter_id(self, filter_id: FilterID) -> None:
+        self.filter_id = filter_id
+        await self.db.execute(update_filter_q, self.filter_id, self.user_id)
+
+    async def put_next_batch(self, next_batch: SyncToken) -> None:
+        self.next_batch = next_batch
+        await self.db.execute(update_nb_q, self.next_batch, self.user_id)
+
+    async def get_next_batch(self) -> SyncToken:
+        return self.next_batch
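
This is the core of the migration: the SQLAlchemy-declarative NextBatch row becomes an attrs dataclass that runs its own queries against the async Database and doubles as the client's sync store. A usage sketch based on what the hunk defines; the user ID, filter ID and sync token are placeholders, and the CREATE TABLE line only exists to keep the sketch self-contained (maubot expects the table to exist already):

    from mautrix.types import FilterID, SyncToken, UserID
    from mautrix.util.async_db import Database

    # NextBatch as defined in the hunk above.

    async def demo() -> None:
        db = Database.create("sqlite:///bot.db")
        await db.start()
        await db.execute(
            "CREATE TABLE IF NOT EXISTS standalone_next_batch "
            "(user_id TEXT PRIMARY KEY, next_batch TEXT, filter_id TEXT)"
        )

        # Load the existing row for this user, or insert an empty one.
        nb = await NextBatch(db, UserID("@standalone:example.com")).load()
        if not nb.filter_id:
            await nb.put_filter_id(FilterID("example-filter"))
        await nb.put_next_batch(SyncToken("s123_456"))
        print(await nb.get_next_batch())

        await db.stop()
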
@@ -37,6 +37,14 @@ server:
 # SQLite and Postgres are supported.
 database: sqlite:///bot.db
 
+# Additional arguments for asyncpg.create_pool() or sqlite3.connect()
+# https://magicstack.github.io/asyncpg/current/api/index.html#asyncpg.pool.create_pool
+# https://docs.python.org/3/library/sqlite3.html#sqlite3.connect
+# For sqlite, min_size is used as the connection thread pool size and max_size is ignored.
+database_opts:
+    min_size: 1
+    max_size: 10
+
 # Config for the plugin. Refer to the plugin's base-config.yaml to find what (if anything) to put here.
 plugin_config: {}
 
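
The new database_opts block is passed through verbatim as db_args, so any keyword accepted by asyncpg.create_pool() (for Postgres URLs) or sqlite3.connect() (for SQLite URLs) can go here. A sketch of how the standalone runner ends up consuming it, parsing the YAML with ruamel.yaml as the rest of the code does (the inline YAML is just the example above, trimmed):

    from ruamel.yaml import YAML

    from mautrix.util.async_db import Database

    raw_config = YAML().load(
        "database: sqlite:///bot.db\n"
        "database_opts:\n"
        "    min_size: 1\n"
        "    max_size: 10\n"
    )

    db = Database.create(
        raw_config["database"],
        db_args=dict(raw_config.get("database_opts") or {}),
        ignore_foreign_tables=True,
    )
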