Add support for asyncpg plugin databases

This commit is contained in:
Tulir Asokan 2022-03-26 13:59:49 +02:00
parent 4b234e4d34
commit 4d8e1475e6
12 changed files with 258 additions and 39 deletions

View file

@ -17,8 +17,8 @@ function fixconfig {
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
fixdefault '.plugin_directories.db' './plugins' '/data/dbs'
fixdefault '.plugin_directories.db' './dbs' '/data/dbs'
fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
# This doesn't need to be configurable
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import asyncio
import sys
from mautrix.util.async_db import Database, DatabaseException
from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
from mautrix.util.program import Program
from .__meta__ import __version__
@ -43,6 +43,7 @@ class Maubot(Program):
server: MaubotServer
db: Database
crypto_db: Database | None
plugin_postgres_db: PostgresDatabase | None
state_store: PgStateStore
config_class = Config
@ -71,13 +72,7 @@ class Maubot(Program):
help="Run even if the database contains tables from other programs (like Synapse)",
)
def prepare(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
def prepare_db(self) -> None:
self.db = Database.create(
self.config["database"],
upgrade_table=upgrade_table,
@ -86,6 +81,7 @@ class Maubot(Program):
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
init_db(self.db)
if self.config["crypto_database"] == "default":
self.crypto_db = self.db
else:
@ -94,6 +90,40 @@ class Maubot(Program):
upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables,
)
if self.config["plugin_databases.postgres"] == "default":
if self.db.scheme != Scheme.POSTGRES:
self.log.critical(
'Using "default" as the postgres plugin database URL is only allowed if '
"the default database is postgres."
)
sys.exit(24)
assert isinstance(self.db, PostgresDatabase)
self.plugin_postgres_db = self.db
elif self.config["plugin_databases.postgres"]:
plugin_db = Database.create(
self.config["plugin_databases.postgres"],
db_args={
**self.config["database_opts"],
**self.config["plugin_databases.postgres_opts"],
},
)
if plugin_db.scheme != Scheme.POSTGRES:
self.log.critical("The plugin postgres database URL must be a postgres database")
sys.exit(24)
assert isinstance(plugin_db, PostgresDatabase)
self.plugin_postgres_db = plugin_db
else:
self.plugin_postgres_db = None
def prepare(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
self.prepare_db()
Client.init_cls(self)
PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop)

View file

@ -42,7 +42,12 @@ class Config(BaseFileConfig):
copy("plugin_directories.upload")
copy("plugin_directories.load")
copy("plugin_directories.trash")
copy("plugin_directories.db")
if "plugin_directories.db" in self:
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
else:
copy("plugin_databases.sqlite")
copy("plugin_databases.postgres")
copy("plugin_databases.postgres_opts")
copy("server.hostname")
copy("server.port")
copy("server.public_url")

View file

@ -16,6 +16,7 @@ database_opts:
min_size: 1
max_size: 10
# Configuration for storing plugin .mbp files
plugin_directories:
# The directory where uploaded new plugins should be stored.
upload: ./plugins
@ -26,8 +27,27 @@ plugin_directories:
# The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately.
trash: ./trash
# The directory where plugin databases should be stored.
db: ./plugins
# Configuration for storing plugin databases
plugin_databases:
# The directory where SQLite plugin databases should be stored.
sqlite: ./plugins
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
#
# To use the same connection pool as the default database, set to "default"
# (the default database above must be postgres to do this).
#
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
# To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
# SQL queries/psql commands.
postgres: null
# Maximum number of connections per plugin instance.
postgres_max_conns_per_plugin: 3
# Overrides for the default database_opts when using a non-"default" postgres connection string.
postgres_opts: {}
server:
# The IP and port to listen to.

View file

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from collections import defaultdict
import asyncio
import inspect
@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger
from .client import Client
from .db import Instance as DBInstance
from .loader import PluginLoader, ZippedPluginLoader
from .lib.plugin_db import ProxyPostgresDatabase
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin
if TYPE_CHECKING:
from .__main__ import Maubot
from .server import PluginWebApp
log = logging.getLogger("maubot.instance")
log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
yaml = YAML()
yaml.indent(4)
@ -60,7 +64,7 @@ class PluginInstance(DBInstance):
config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine | None
inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
@ -130,8 +134,6 @@ class PluginInstance(DBInstance):
self.log.error(f"Failed to get client for user {self.primary_user}")
await self.update_enabled(False)
return False
if self.loader.meta.database:
self.enable_database()
if self.loader.meta.webapp:
self.enable_webapp()
self.log.debug("Plugin instance dependencies loaded")
@ -147,9 +149,9 @@ class PluginInstance(DBInstance):
self.inst_webapp = None
self.inst_webapp_url = None
def enable_database(self) -> None:
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
@property
def _sqlite_db_path(self) -> str:
return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
async def delete(self) -> None:
if self.loader is not None:
@ -162,11 +164,8 @@ class PluginInstance(DBInstance):
pass
await super().delete()
if self.inst_db:
self.inst_db.dispose()
ZippedPluginLoader.trash(
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted",
)
await self.stop_database()
await self.delete_database()
if self.inst_webapp:
self.disable_webapp()
@ -178,6 +177,56 @@ class PluginInstance(DBInstance):
yaml.dump(data, buf)
self.config_str = buf.getvalue()
async def start_database(
self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
instance_db_log = db_log.getChild(self.id)
# TODO should there be a way to choose between SQLite and Postgres
# for individual instances? Maybe checking the existence of the SQLite file.
if self.maubot.plugin_postgres_db:
self.inst_db = ProxyPostgresDatabase(
pool=self.maubot.plugin_postgres_db,
instance_id=self.id,
max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
upgrade_table=upgrade_table,
log=instance_db_log,
)
else:
self.inst_db = Database.create(
f"sqlite:///{self._sqlite_db_path}",
upgrade_table=upgrade_table,
log=instance_db_log,
)
if actually_start:
await self.inst_db.start()
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
async def stop_database(self) -> None:
if isinstance(self.inst_db, Database):
await self.inst_db.stop()
elif isinstance(self.inst_db, sql.engine.Engine):
self.inst_db.dispose()
else:
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
async def delete_database(self) -> None:
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
if self.inst_db is None:
await self.start_database(None, actually_start=False)
if isinstance(self.inst_db, ProxyPostgresDatabase):
await self.inst_db.delete()
else:
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
else:
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
self.inst_db = None
async def start(self) -> None:
if self.started:
self.log.warning("Ignoring start() call to already started plugin")
@ -196,9 +245,8 @@ class PluginInstance(DBInstance):
elif not self.loader.meta.webapp and self.inst_webapp is not None:
self.log.debug("Disabling webapp after plugin meta reload")
self.disable_webapp()
if self.loader.meta.database and self.inst_db is None:
self.log.debug("Enabling database after plugin meta reload")
self.enable_database()
if self.loader.meta.database:
await self.start_database(cls.get_db_upgrade_table())
config_class = cls.get_config_class()
if config_class:
try:
@ -254,6 +302,11 @@ class PluginInstance(DBInstance):
except Exception:
self.log.exception("Failed to stop instance")
self.plugin = None
if self.inst_db:
try:
await self.stop_database()
except Exception:
self.log.exception("Failed to stop instance database")
self.inst_db_tables = None
async def update_id(self, new_id: str | None) -> None:

94
maubot/lib/plugin_db.py Normal file
View file

@ -0,0 +1,94 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from contextlib import asynccontextmanager
import asyncio
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
from mautrix.util.async_db.connection import LoggingConnection
from mautrix.util.logging import TraceLogger
remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database):
scheme = Scheme.POSTGRES
_underlying_pool: PostgresDatabase
_schema: str
_default_search_path: str
_conn_sema: asyncio.Semaphore
def __init__(
self,
pool: PostgresDatabase,
instance_id: str,
max_conns: int,
upgrade_table: UpgradeTable | None,
log: TraceLogger | None = None,
) -> None:
super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
self._underlying_pool = pool
# Simple accidental SQL injection prevention.
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
self._default_search_path = '"$user", public'
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
async def start(self) -> None:
async with self._underlying_pool.acquire() as conn:
self._default_search_path = await conn.fetchval("SHOW search_path")
self.log.debug(f"Found default search path: {self._default_search_path}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
await super().start()
async def stop(self) -> None:
while not self._conn_sema.locked():
try:
await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
except asyncio.TimeoutError:
self.log.warning(
"Failed to drain plugin database connection pool, "
"the plugin may be leaking database connections"
)
break
async def delete(self) -> None:
self.log.debug(f"Deleting schema {self._schema} and all data in it")
try:
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
except Exception:
self.log.warning("Failed to delete schema", exc_info=True)
@asynccontextmanager
async def acquire(self) -> LoggingConnection:
conn: LoggingConnection
async with self._conn_sema, self._underlying_pool.acquire() as conn:
await conn.execute(f"SET search_path = {self._default_search_path}")
try:
yield conn
finally:
if not conn.wrapped.is_closed():
try:
await conn.execute(f"SET search_path = {self._default_search_path}")
except Exception:
self.log.exception("Error resetting search_path after use")
await conn.wrapped.close()
else:
self.log.debug("Connection was closed after use, not resetting search_path")
__all__ = ["ProxyPostgresDatabase"]

View file

@ -1,2 +1,3 @@
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
from .meta import DatabaseType, PluginMeta
from .zip import MaubotZipImportError, ZippedPluginLoader

View file

@ -18,7 +18,13 @@ from typing import List
from attr import dataclass
from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
from mautrix.types import (
ExtensibleEnum,
SerializableAttrs,
SerializerError,
deserializer,
serializer,
)
from ..__meta__ import __version__
@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version:
raise SerializerError("Invalid version") from e
class DatabaseType(ExtensibleEnum):
SQLALCHEMY = "sqlalchemy"
ASYNCPG = "asyncpg"
@dataclass
class PluginMeta(SerializableAttrs):
id: str
@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs):
maubot: Version = Version(__version__)
database: bool = False
database_type: DatabaseType = DatabaseType.SQLALCHEMY
config: bool = False
webapp: bool = False
license: str = ""

View file

@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
instance.enabled = data.get("enabled", True)
instance.config_str = data.get("config") or ""
await instance.update()
await instance.load()
await instance.start()
return resp.created(instance.to_dict())

View file

@ -23,10 +23,11 @@ from aiohttp import ClientSession
from sqlalchemy.engine.base import Engine
from yarl import URL
if TYPE_CHECKING:
from mautrix.util.config import BaseProxyConfig
from mautrix.util.logging import TraceLogger
from mautrix.util.async_db import Database, UpgradeTable
from mautrix.util.config import BaseProxyConfig
from mautrix.util.logging import TraceLogger
if TYPE_CHECKING:
from .client import MaubotMatrixClient
from .loader import BasePluginLoader
from .plugin_server import PluginWebApp
@ -40,7 +41,7 @@ class Plugin(ABC):
loop: AbstractEventLoop
loader: BasePluginLoader
config: BaseProxyConfig | None
database: Engine | None
database: Engine | Database | None
webapp: PluginWebApp | None
webapp_url: URL | None
@ -124,6 +125,10 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return None
def on_external_config_update(self) -> Awaitable[None] | None:
if self.config:
self.config.load_and_update()

View file

@ -1,9 +1,6 @@
# Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group.
#/sqlite
aiosqlite>=0.16,<0.18
#/encryption
python-olm>=3,<4
pycryptodome>=3,<4

View file

@ -3,6 +3,7 @@ aiohttp>=3,<4
yarl>=1,<2
SQLAlchemy>=1,<1.4
asyncpg>=0.20,<0.26
aiosqlite>=0.16,<0.18
alembic>=1,<2
commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18