Add support for asyncpg plugin databases

This commit is contained in:
Tulir Asokan 2022-03-26 13:59:49 +02:00
parent 4b234e4d34
commit 4d8e1475e6
12 changed files with 258 additions and 39 deletions

View file

@ -17,8 +17,8 @@ function fixconfig {
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins' fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins' fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
fixdefault '.plugin_directories.trash' './trash' '/data/trash' fixdefault '.plugin_directories.trash' './trash' '/data/trash'
fixdefault '.plugin_directories.db' './plugins' '/data/dbs' fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
fixdefault '.plugin_directories.db' './dbs' '/data/dbs' fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log' fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
# This doesn't need to be configurable # This doesn't need to be configurable
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml

View file

@ -18,7 +18,7 @@ from __future__ import annotations
import asyncio import asyncio
import sys import sys
from mautrix.util.async_db import Database, DatabaseException from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
from mautrix.util.program import Program from mautrix.util.program import Program
from .__meta__ import __version__ from .__meta__ import __version__
@ -43,6 +43,7 @@ class Maubot(Program):
server: MaubotServer server: MaubotServer
db: Database db: Database
crypto_db: Database | None crypto_db: Database | None
plugin_postgres_db: PostgresDatabase | None
state_store: PgStateStore state_store: PgStateStore
config_class = Config config_class = Config
@ -71,13 +72,7 @@ class Maubot(Program):
help="Run even if the database contains tables from other programs (like Synapse)", help="Run even if the database contains tables from other programs (like Synapse)",
) )
def prepare(self) -> None: def prepare_db(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
self.db = Database.create( self.db = Database.create(
self.config["database"], self.config["database"],
upgrade_table=upgrade_table, upgrade_table=upgrade_table,
@ -86,6 +81,7 @@ class Maubot(Program):
ignore_foreign_tables=self.args.ignore_foreign_tables, ignore_foreign_tables=self.args.ignore_foreign_tables,
) )
init_db(self.db) init_db(self.db)
if self.config["crypto_database"] == "default": if self.config["crypto_database"] == "default":
self.crypto_db = self.db self.crypto_db = self.db
else: else:
@ -94,6 +90,40 @@ class Maubot(Program):
upgrade_table=PgCryptoStore.upgrade_table, upgrade_table=PgCryptoStore.upgrade_table,
ignore_foreign_tables=self.args.ignore_foreign_tables, ignore_foreign_tables=self.args.ignore_foreign_tables,
) )
if self.config["plugin_databases.postgres"] == "default":
if self.db.scheme != Scheme.POSTGRES:
self.log.critical(
'Using "default" as the postgres plugin database URL is only allowed if '
"the default database is postgres."
)
sys.exit(24)
assert isinstance(self.db, PostgresDatabase)
self.plugin_postgres_db = self.db
elif self.config["plugin_databases.postgres"]:
plugin_db = Database.create(
self.config["plugin_databases.postgres"],
db_args={
**self.config["database_opts"],
**self.config["plugin_databases.postgres_opts"],
},
)
if plugin_db.scheme != Scheme.POSTGRES:
self.log.critical("The plugin postgres database URL must be a postgres database")
sys.exit(24)
assert isinstance(plugin_db, PostgresDatabase)
self.plugin_postgres_db = plugin_db
else:
self.plugin_postgres_db = None
def prepare(self) -> None:
super().prepare()
if self.config["api_features.log"]:
self.prepare_log_websocket()
init_zip_loader(self.config)
self.prepare_db()
Client.init_cls(self) Client.init_cls(self)
PluginInstance.init_cls(self) PluginInstance.init_cls(self)
management_api = init_mgmt_api(self.config, self.loop) management_api = init_mgmt_api(self.config, self.loop)

View file

@ -42,7 +42,12 @@ class Config(BaseFileConfig):
copy("plugin_directories.upload") copy("plugin_directories.upload")
copy("plugin_directories.load") copy("plugin_directories.load")
copy("plugin_directories.trash") copy("plugin_directories.trash")
copy("plugin_directories.db") if "plugin_directories.db" in self:
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
else:
copy("plugin_databases.sqlite")
copy("plugin_databases.postgres")
copy("plugin_databases.postgres_opts")
copy("server.hostname") copy("server.hostname")
copy("server.port") copy("server.port")
copy("server.public_url") copy("server.public_url")

View file

@ -16,6 +16,7 @@ database_opts:
min_size: 1 min_size: 1
max_size: 10 max_size: 10
# Configuration for storing plugin .mbp files
plugin_directories: plugin_directories:
# The directory where uploaded new plugins should be stored. # The directory where uploaded new plugins should be stored.
upload: ./plugins upload: ./plugins
@ -26,8 +27,27 @@ plugin_directories:
# The directory where old plugin versions and conflicting plugins should be moved. # The directory where old plugin versions and conflicting plugins should be moved.
# Set to "delete" to delete files immediately. # Set to "delete" to delete files immediately.
trash: ./trash trash: ./trash
# The directory where plugin databases should be stored.
db: ./plugins # Configuration for storing plugin databases
plugin_databases:
# The directory where SQLite plugin databases should be stored.
sqlite: ./plugins
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
#
# To use the same connection pool as the default database, set to "default"
# (the default database above must be postgres to do this).
#
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
# To view schemas in psql, use `\dn`. To enter and interact with a specific schema,
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
# SQL queries/psql commands.
postgres: null
# Maximum number of connections per plugin instance.
postgres_max_conns_per_plugin: 3
# Overrides for the default database_opts when using a non-"default" postgres connection string.
postgres_opts: {}
server: server:
# The IP and port to listen to. # The IP and port to listen to.

View file

@ -15,7 +15,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations from __future__ import annotations
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
from collections import defaultdict from collections import defaultdict
import asyncio import asyncio
import inspect import inspect
@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger
from .client import Client from .client import Client
from .db import Instance as DBInstance from .db import Instance as DBInstance
from .loader import PluginLoader, ZippedPluginLoader from .lib.plugin_db import ProxyPostgresDatabase
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
from .plugin_base import Plugin from .plugin_base import Plugin
if TYPE_CHECKING: if TYPE_CHECKING:
from .__main__ import Maubot from .__main__ import Maubot
from .server import PluginWebApp from .server import PluginWebApp
log = logging.getLogger("maubot.instance") log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
yaml = YAML() yaml = YAML()
yaml.indent(4) yaml.indent(4)
@ -60,7 +64,7 @@ class PluginInstance(DBInstance):
config: BaseProxyConfig | None config: BaseProxyConfig | None
base_cfg: RecursiveDict[CommentedMap] | None base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None base_cfg_str: str | None
inst_db: sql.engine.Engine | None inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict[str, sql.Table] | None inst_db_tables: dict[str, sql.Table] | None
inst_webapp: PluginWebApp | None inst_webapp: PluginWebApp | None
inst_webapp_url: str | None inst_webapp_url: str | None
@ -130,8 +134,6 @@ class PluginInstance(DBInstance):
self.log.error(f"Failed to get client for user {self.primary_user}") self.log.error(f"Failed to get client for user {self.primary_user}")
await self.update_enabled(False) await self.update_enabled(False)
return False return False
if self.loader.meta.database:
self.enable_database()
if self.loader.meta.webapp: if self.loader.meta.webapp:
self.enable_webapp() self.enable_webapp()
self.log.debug("Plugin instance dependencies loaded") self.log.debug("Plugin instance dependencies loaded")
@ -147,9 +149,9 @@ class PluginInstance(DBInstance):
self.inst_webapp = None self.inst_webapp = None
self.inst_webapp_url = None self.inst_webapp_url = None
def enable_database(self) -> None: @property
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id) def _sqlite_db_path(self) -> str:
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db") return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
async def delete(self) -> None: async def delete(self) -> None:
if self.loader is not None: if self.loader is not None:
@ -162,11 +164,8 @@ class PluginInstance(DBInstance):
pass pass
await super().delete() await super().delete()
if self.inst_db: if self.inst_db:
self.inst_db.dispose() await self.stop_database()
ZippedPluginLoader.trash( await self.delete_database()
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
reason="deleted",
)
if self.inst_webapp: if self.inst_webapp:
self.disable_webapp() self.disable_webapp()
@ -178,6 +177,56 @@ class PluginInstance(DBInstance):
yaml.dump(data, buf) yaml.dump(data, buf)
self.config_str = buf.getvalue() self.config_str = buf.getvalue()
async def start_database(
    self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
) -> None:
    """Create (and optionally start) the database object for this instance.

    The backend is picked from the plugin's declared ``database_type``:

    * ``SQLALCHEMY`` -- a SQLAlchemy engine on the instance's SQLite file.
    * ``ASYNCPG`` -- an async ``Database``: a schema-scoped proxy over the
      shared plugin Postgres pool when one is configured, otherwise an
      SQLite database on the instance's SQLite file.

    :param upgrade_table: Upgrade table handed to the async database
        (unused for the SQLAlchemy engine).
    :param actually_start: When false, the async database object is only
        constructed and ``start()`` is not awaited.
    :raises RuntimeError: If the plugin declares an unknown database type.
    """
    declared_type = self.loader.meta.database_type

    if declared_type == DatabaseType.SQLALCHEMY:
        self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
        return

    if declared_type == DatabaseType.ASYNCPG:
        instance_db_log = db_log.getChild(self.id)
        # TODO should there be a way to choose between SQLite and Postgres
        #      for individual instances? Maybe checking the existence of the SQLite file.
        if self.maubot.plugin_postgres_db:
            self.inst_db = ProxyPostgresDatabase(
                pool=self.maubot.plugin_postgres_db,
                instance_id=self.id,
                max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
                upgrade_table=upgrade_table,
                log=instance_db_log,
            )
        else:
            self.inst_db = Database.create(
                f"sqlite:///{self._sqlite_db_path}",
                upgrade_table=upgrade_table,
                log=instance_db_log,
            )
        if actually_start:
            await self.inst_db.start()
        return

    raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
async def stop_database(self) -> None:
    """Shut down whichever kind of database object the instance holds.

    Async ``Database`` objects are stopped, SQLAlchemy engines are disposed.

    :raises RuntimeError: If ``inst_db`` is neither of those (including ``None``).
    """
    current_db = self.inst_db

    if isinstance(current_db, Database):
        await current_db.stop()
        return

    if isinstance(current_db, sql.engine.Engine):
        current_db.dispose()
        return

    raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
async def delete_database(self) -> None:
    """Remove the instance's stored data and forget the database object.

    SQLite-backed data (SQLAlchemy plugins, or asyncpg-type plugins that are
    not running on the Postgres proxy) is moved to the trash; a Postgres
    proxy database has its ``delete()`` awaited instead.

    :raises RuntimeError: If the plugin declares an unknown database type.
    """
    declared_type = self.loader.meta.database_type

    if declared_type == DatabaseType.SQLALCHEMY:
        trash_sqlite_file = True
    elif declared_type == DatabaseType.ASYNCPG:
        if self.inst_db is None:
            # Build the database object without starting it, only to learn
            # which backend this instance would be using.
            await self.start_database(None, actually_start=False)
        trash_sqlite_file = not isinstance(self.inst_db, ProxyPostgresDatabase)
    else:
        raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")

    if trash_sqlite_file:
        ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
    else:
        await self.inst_db.delete()

    self.inst_db = None
async def start(self) -> None: async def start(self) -> None:
if self.started: if self.started:
self.log.warning("Ignoring start() call to already started plugin") self.log.warning("Ignoring start() call to already started plugin")
@ -196,9 +245,8 @@ class PluginInstance(DBInstance):
elif not self.loader.meta.webapp and self.inst_webapp is not None: elif not self.loader.meta.webapp and self.inst_webapp is not None:
self.log.debug("Disabling webapp after plugin meta reload") self.log.debug("Disabling webapp after plugin meta reload")
self.disable_webapp() self.disable_webapp()
if self.loader.meta.database and self.inst_db is None: if self.loader.meta.database:
self.log.debug("Enabling database after plugin meta reload") await self.start_database(cls.get_db_upgrade_table())
self.enable_database()
config_class = cls.get_config_class() config_class = cls.get_config_class()
if config_class: if config_class:
try: try:
@ -254,6 +302,11 @@ class PluginInstance(DBInstance):
except Exception: except Exception:
self.log.exception("Failed to stop instance") self.log.exception("Failed to stop instance")
self.plugin = None self.plugin = None
if self.inst_db:
try:
await self.stop_database()
except Exception:
self.log.exception("Failed to stop instance database")
self.inst_db_tables = None self.inst_db_tables = None
async def update_id(self, new_id: str | None) -> None: async def update_id(self, new_id: str | None) -> None:

94
maubot/lib/plugin_db.py Normal file
View file

@ -0,0 +1,94 @@
# maubot - A plugin-based Matrix bot system.
# Copyright (C) 2022 Tulir Asokan
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from contextlib import asynccontextmanager
import asyncio
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
from mautrix.util.async_db.connection import LoggingConnection
from mautrix.util.logging import TraceLogger
remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database):
    """Schema-scoped view of a shared Postgres pool for one plugin instance.

    Connections are borrowed from ``pool`` rather than opened by this object.
    While a connection is lent out through :meth:`acquire`, its
    ``search_path`` points at the instance's own ``mbp_<instance id>`` schema;
    the default search path is restored before the connection goes back to
    the shared pool. A bounded semaphore caps how many connections a single
    plugin instance can hold at once.
    """

    scheme = Scheme.POSTGRES
    _underlying_pool: PostgresDatabase
    _schema: str
    _default_search_path: str
    _conn_sema: asyncio.Semaphore

    def __init__(
        self,
        pool: PostgresDatabase,
        instance_id: str,
        max_conns: int,
        upgrade_table: UpgradeTable | None,
        log: TraceLogger | None = None,
    ) -> None:
        """
        :param pool: The shared Postgres database whose connections are borrowed.
        :param instance_id: Plugin instance ID, used to derive the schema name.
        :param max_conns: Maximum number of simultaneously acquired connections.
        :param upgrade_table: Upgrade table passed through to ``Database``.
        :param log: Logger passed through to ``Database``.
        """
        super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
        self._underlying_pool = pool
        # Simple accidental SQL injection prevention.
        # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
        self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
        # Placeholder until start() reads the real value from the server.
        self._default_search_path = '"$user", public'
        self._conn_sema = asyncio.BoundedSemaphore(max_conns)

    async def start(self) -> None:
        """Record the server's default search path, create the schema, then start.

        The schema has to exist before ``super().start()`` runs, because any
        connection it takes through :meth:`acquire` is switched into it.
        """
        async with self._underlying_pool.acquire() as conn:
            self._default_search_path = await conn.fetchval("SHOW search_path")
            self.log.debug(f"Found default search path: {self._default_search_path}")
            await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
        await super().start()

    async def stop(self) -> None:
        """Wait for the plugin to hand back all of its connections.

        Takes semaphore permits until none are left, which can only happen
        once every outstanding :meth:`acquire` context has exited. Gives up
        with a warning if a permit cannot be obtained within 3 seconds. The
        underlying shared pool is not stopped here.
        """
        while not self._conn_sema.locked():
            try:
                await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
            except asyncio.TimeoutError:
                self.log.warning(
                    "Failed to drain plugin database connection pool, "
                    "the plugin may be leaking database connections"
                )
                break

    async def delete(self) -> None:
        """Drop the instance's schema and everything in it (best effort).

        Failures are logged as warnings and not re-raised.
        """
        self.log.debug(f"Deleting schema {self._schema} and all data in it")
        try:
            await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
        except Exception:
            self.log.warning("Failed to delete schema", exc_info=True)

    @asynccontextmanager
    async def acquire(self) -> LoggingConnection:
        """Lend out a pooled connection scoped to this instance's schema.

        Holds one semaphore permit for the duration of the context, so at
        most ``max_conns`` connections are in use by the instance at a time.
        """
        conn: LoggingConnection
        async with self._conn_sema, self._underlying_pool.acquire() as conn:
            # Point the connection at the plugin's own schema. Setting the
            # default search path here would make every plugin query (and
            # schema upgrade) operate on the shared default schemas instead.
            await conn.execute(f"SET search_path = {self._schema}")
            try:
                yield conn
            finally:
                if not conn.wrapped.is_closed():
                    try:
                        await conn.execute(f"SET search_path = {self._default_search_path}")
                    except Exception:
                        self.log.exception("Error resetting search_path after use")
                        # Presumably closed so that a connection still scoped
                        # to the plugin schema is never reused by the shared pool.
                        await conn.wrapped.close()
                else:
                    self.log.debug("Connection was closed after use, not resetting search_path")
__all__ = ["ProxyPostgresDatabase"]

View file

@ -1,2 +1,3 @@
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
from .meta import DatabaseType, PluginMeta
from .zip import MaubotZipImportError, ZippedPluginLoader from .zip import MaubotZipImportError, ZippedPluginLoader

View file

@ -18,7 +18,13 @@ from typing import List
from attr import dataclass from attr import dataclass
from packaging.version import InvalidVersion, Version from packaging.version import InvalidVersion, Version
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer from mautrix.types import (
ExtensibleEnum,
SerializableAttrs,
SerializerError,
deserializer,
serializer,
)
from ..__meta__ import __version__ from ..__meta__ import __version__
@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version:
raise SerializerError("Invalid version") from e raise SerializerError("Invalid version") from e
# Database interface a plugin declares in its metadata (the `database_type` field).
class DatabaseType(ExtensibleEnum):
# Legacy interface: the instance is given a SQLAlchemy engine on a SQLite file.
SQLALCHEMY = "sqlalchemy"
# Async interface: the instance is given a mautrix async `Database`
# (Postgres-backed when a plugin Postgres database is configured, else SQLite).
ASYNCPG = "asyncpg"
@dataclass @dataclass
class PluginMeta(SerializableAttrs): class PluginMeta(SerializableAttrs):
id: str id: str
@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs):
maubot: Version = Version(__version__) maubot: Version = Version(__version__)
database: bool = False database: bool = False
database_type: DatabaseType = DatabaseType.SQLALCHEMY
config: bool = False config: bool = False
webapp: bool = False webapp: bool = False
license: str = "" license: str = ""

View file

@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
instance.enabled = data.get("enabled", True) instance.enabled = data.get("enabled", True)
instance.config_str = data.get("config") or "" instance.config_str = data.get("config") or ""
await instance.update() await instance.update()
await instance.load()
await instance.start() await instance.start()
return resp.created(instance.to_dict()) return resp.created(instance.to_dict())

View file

@ -23,10 +23,11 @@ from aiohttp import ClientSession
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from yarl import URL from yarl import URL
if TYPE_CHECKING: from mautrix.util.async_db import Database, UpgradeTable
from mautrix.util.config import BaseProxyConfig from mautrix.util.config import BaseProxyConfig
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
if TYPE_CHECKING:
from .client import MaubotMatrixClient from .client import MaubotMatrixClient
from .loader import BasePluginLoader from .loader import BasePluginLoader
from .plugin_server import PluginWebApp from .plugin_server import PluginWebApp
@ -40,7 +41,7 @@ class Plugin(ABC):
loop: AbstractEventLoop loop: AbstractEventLoop
loader: BasePluginLoader loader: BasePluginLoader
config: BaseProxyConfig | None config: BaseProxyConfig | None
database: Engine | None database: Engine | Database | None
webapp: PluginWebApp | None webapp: PluginWebApp | None
webapp_url: URL | None webapp_url: URL | None
@ -124,6 +125,10 @@ class Plugin(ABC):
def get_config_class(cls) -> type[BaseProxyConfig] | None: def get_config_class(cls) -> type[BaseProxyConfig] | None:
return None return None
# Hook for plugins using the async database interface: return the UpgradeTable
# describing the plugin's schema upgrades. The base implementation returns None,
# meaning no upgrade table is supplied.
# NOTE(review): presumably only consulted when the plugin enables a database -- confirm against the caller.
@classmethod
def get_db_upgrade_table(cls) -> UpgradeTable | None:
return None
def on_external_config_update(self) -> Awaitable[None] | None: def on_external_config_update(self) -> Awaitable[None] | None:
if self.config: if self.config:
self.config.load_and_update() self.config.load_and_update()

View file

@ -1,9 +1,6 @@
# Format: #/name defines a new extras_require group called name # Format: #/name defines a new extras_require group called name
# Uncommented lines after the group definition insert things into that group. # Uncommented lines after the group definition insert things into that group.
#/sqlite
aiosqlite>=0.16,<0.18
#/encryption #/encryption
python-olm>=3,<4 python-olm>=3,<4
pycryptodome>=3,<4 pycryptodome>=3,<4

View file

@ -3,6 +3,7 @@ aiohttp>=3,<4
yarl>=1,<2 yarl>=1,<2
SQLAlchemy>=1,<1.4 SQLAlchemy>=1,<1.4
asyncpg>=0.20,<0.26 asyncpg>=0.20,<0.26
aiosqlite>=0.16,<0.18
alembic>=1,<2 alembic>=1,<2
commonmark>=0.9,<1 commonmark>=0.9,<1
ruamel.yaml>=0.15.35,<0.18 ruamel.yaml>=0.15.35,<0.18