Add support for asyncpg plugin databases
This commit is contained in:
parent
4b234e4d34
commit
4d8e1475e6
12 changed files with 258 additions and 39 deletions
|
@ -17,8 +17,8 @@ function fixconfig {
|
||||||
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
|
fixdefault '.plugin_directories.upload' './plugins' '/data/plugins'
|
||||||
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
|
fixdefault '.plugin_directories.load[0]' './plugins' '/data/plugins'
|
||||||
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
|
fixdefault '.plugin_directories.trash' './trash' '/data/trash'
|
||||||
fixdefault '.plugin_directories.db' './plugins' '/data/dbs'
|
fixdefault '.plugin_databases.sqlite' './plugins' '/data/dbs'
|
||||||
fixdefault '.plugin_directories.db' './dbs' '/data/dbs'
|
fixdefault '.plugin_databases.sqlite' './dbs' '/data/dbs'
|
||||||
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
|
fixdefault '.logging.handlers.file.filename' './maubot.log' '/var/log/maubot.log'
|
||||||
# This doesn't need to be configurable
|
# This doesn't need to be configurable
|
||||||
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
|
yq e -i '.server.override_resource_path = "/opt/maubot/frontend"' /data/config.yaml
|
||||||
|
|
|
@ -18,7 +18,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from mautrix.util.async_db import Database, DatabaseException
|
from mautrix.util.async_db import Database, DatabaseException, PostgresDatabase, Scheme
|
||||||
from mautrix.util.program import Program
|
from mautrix.util.program import Program
|
||||||
|
|
||||||
from .__meta__ import __version__
|
from .__meta__ import __version__
|
||||||
|
@ -43,6 +43,7 @@ class Maubot(Program):
|
||||||
server: MaubotServer
|
server: MaubotServer
|
||||||
db: Database
|
db: Database
|
||||||
crypto_db: Database | None
|
crypto_db: Database | None
|
||||||
|
plugin_postgres_db: PostgresDatabase | None
|
||||||
state_store: PgStateStore
|
state_store: PgStateStore
|
||||||
|
|
||||||
config_class = Config
|
config_class = Config
|
||||||
|
@ -71,13 +72,7 @@ class Maubot(Program):
|
||||||
help="Run even if the database contains tables from other programs (like Synapse)",
|
help="Run even if the database contains tables from other programs (like Synapse)",
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare(self) -> None:
|
def prepare_db(self) -> None:
|
||||||
super().prepare()
|
|
||||||
|
|
||||||
if self.config["api_features.log"]:
|
|
||||||
self.prepare_log_websocket()
|
|
||||||
|
|
||||||
init_zip_loader(self.config)
|
|
||||||
self.db = Database.create(
|
self.db = Database.create(
|
||||||
self.config["database"],
|
self.config["database"],
|
||||||
upgrade_table=upgrade_table,
|
upgrade_table=upgrade_table,
|
||||||
|
@ -86,6 +81,7 @@ class Maubot(Program):
|
||||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||||
)
|
)
|
||||||
init_db(self.db)
|
init_db(self.db)
|
||||||
|
|
||||||
if self.config["crypto_database"] == "default":
|
if self.config["crypto_database"] == "default":
|
||||||
self.crypto_db = self.db
|
self.crypto_db = self.db
|
||||||
else:
|
else:
|
||||||
|
@ -94,6 +90,40 @@ class Maubot(Program):
|
||||||
upgrade_table=PgCryptoStore.upgrade_table,
|
upgrade_table=PgCryptoStore.upgrade_table,
|
||||||
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
ignore_foreign_tables=self.args.ignore_foreign_tables,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.config["plugin_databases.postgres"] == "default":
|
||||||
|
if self.db.scheme != Scheme.POSTGRES:
|
||||||
|
self.log.critical(
|
||||||
|
'Using "default" as the postgres plugin database URL is only allowed if '
|
||||||
|
"the default database is postgres."
|
||||||
|
)
|
||||||
|
sys.exit(24)
|
||||||
|
assert isinstance(self.db, PostgresDatabase)
|
||||||
|
self.plugin_postgres_db = self.db
|
||||||
|
elif self.config["plugin_databases.postgres"]:
|
||||||
|
plugin_db = Database.create(
|
||||||
|
self.config["plugin_databases.postgres"],
|
||||||
|
db_args={
|
||||||
|
**self.config["database_opts"],
|
||||||
|
**self.config["plugin_databases.postgres_opts"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
if plugin_db.scheme != Scheme.POSTGRES:
|
||||||
|
self.log.critical("The plugin postgres database URL must be a postgres database")
|
||||||
|
sys.exit(24)
|
||||||
|
assert isinstance(plugin_db, PostgresDatabase)
|
||||||
|
self.plugin_postgres_db = plugin_db
|
||||||
|
else:
|
||||||
|
self.plugin_postgres_db = None
|
||||||
|
|
||||||
|
def prepare(self) -> None:
|
||||||
|
super().prepare()
|
||||||
|
|
||||||
|
if self.config["api_features.log"]:
|
||||||
|
self.prepare_log_websocket()
|
||||||
|
|
||||||
|
init_zip_loader(self.config)
|
||||||
|
self.prepare_db()
|
||||||
Client.init_cls(self)
|
Client.init_cls(self)
|
||||||
PluginInstance.init_cls(self)
|
PluginInstance.init_cls(self)
|
||||||
management_api = init_mgmt_api(self.config, self.loop)
|
management_api = init_mgmt_api(self.config, self.loop)
|
||||||
|
|
|
@ -42,7 +42,12 @@ class Config(BaseFileConfig):
|
||||||
copy("plugin_directories.upload")
|
copy("plugin_directories.upload")
|
||||||
copy("plugin_directories.load")
|
copy("plugin_directories.load")
|
||||||
copy("plugin_directories.trash")
|
copy("plugin_directories.trash")
|
||||||
copy("plugin_directories.db")
|
if "plugin_directories.db" in self:
|
||||||
|
base["plugin_databases.sqlite"] = self["plugin_directories.db"]
|
||||||
|
else:
|
||||||
|
copy("plugin_databases.sqlite")
|
||||||
|
copy("plugin_databases.postgres")
|
||||||
|
copy("plugin_databases.postgres_opts")
|
||||||
copy("server.hostname")
|
copy("server.hostname")
|
||||||
copy("server.port")
|
copy("server.port")
|
||||||
copy("server.public_url")
|
copy("server.public_url")
|
||||||
|
|
|
@ -16,6 +16,7 @@ database_opts:
|
||||||
min_size: 1
|
min_size: 1
|
||||||
max_size: 10
|
max_size: 10
|
||||||
|
|
||||||
|
# Configuration for storing plugin .mbp files
|
||||||
plugin_directories:
|
plugin_directories:
|
||||||
# The directory where uploaded new plugins should be stored.
|
# The directory where uploaded new plugins should be stored.
|
||||||
upload: ./plugins
|
upload: ./plugins
|
||||||
|
@ -26,8 +27,27 @@ plugin_directories:
|
||||||
# The directory where old plugin versions and conflicting plugins should be moved.
|
# The directory where old plugin versions and conflicting plugins should be moved.
|
||||||
# Set to "delete" to delete files immediately.
|
# Set to "delete" to delete files immediately.
|
||||||
trash: ./trash
|
trash: ./trash
|
||||||
# The directory where plugin databases should be stored.
|
|
||||||
db: ./plugins
|
# Configuration for storing plugin databases
|
||||||
|
plugin_databases:
|
||||||
|
# The directory where SQLite plugin databases should be stored.
|
||||||
|
sqlite: ./plugins
|
||||||
|
# The connection URL for plugin databases. If null, all plugins will get SQLite databases.
|
||||||
|
# If set, plugins using the new asyncpg interface will get a Postgres connection instead.
|
||||||
|
# Plugins using the legacy SQLAlchemy interface will always get a SQLite connection.
|
||||||
|
#
|
||||||
|
# To use the same connection pool as the default database, set to "default"
|
||||||
|
# (the default database above must be postgres to do this).
|
||||||
|
#
|
||||||
|
# When enabled, maubot will create separate Postgres schemas in the database for each plugin.
|
||||||
|
# To view schemas in psql, use `\dn`. To view enter and interact with a specific schema,
|
||||||
|
# use `SET search_path = name` (where `name` is the name found with `\dn`) and then use normal
|
||||||
|
# SQL queries/psql commands.
|
||||||
|
postgres: null
|
||||||
|
# Maximum number of connections per plugin instance.
|
||||||
|
postgres_max_conns_per_plugin: 3
|
||||||
|
# Overrides for the default database_opts when using a non-"default" postgres connection string.
|
||||||
|
postgres_opts: {}
|
||||||
|
|
||||||
server:
|
server:
|
||||||
# The IP and port to listen to.
|
# The IP and port to listen to.
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterable, Awaitable, cast
|
from typing import TYPE_CHECKING, Any, AsyncGenerator, cast
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
@ -28,19 +28,23 @@ from ruamel.yaml.comments import CommentedMap
|
||||||
import sqlalchemy as sql
|
import sqlalchemy as sql
|
||||||
|
|
||||||
from mautrix.types import UserID
|
from mautrix.types import UserID
|
||||||
|
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
|
||||||
from mautrix.util.async_getter_lock import async_getter_lock
|
from mautrix.util.async_getter_lock import async_getter_lock
|
||||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||||
|
from mautrix.util.logging import TraceLogger
|
||||||
|
|
||||||
from .client import Client
|
from .client import Client
|
||||||
from .db import Instance as DBInstance
|
from .db import Instance as DBInstance
|
||||||
from .loader import PluginLoader, ZippedPluginLoader
|
from .lib.plugin_db import ProxyPostgresDatabase
|
||||||
|
from .loader import DatabaseType, PluginLoader, ZippedPluginLoader
|
||||||
from .plugin_base import Plugin
|
from .plugin_base import Plugin
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .__main__ import Maubot
|
from .__main__ import Maubot
|
||||||
from .server import PluginWebApp
|
from .server import PluginWebApp
|
||||||
|
|
||||||
log = logging.getLogger("maubot.instance")
|
log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance"))
|
||||||
|
db_log: TraceLogger = cast(TraceLogger, logging.getLogger("maubot.instance_db"))
|
||||||
|
|
||||||
yaml = YAML()
|
yaml = YAML()
|
||||||
yaml.indent(4)
|
yaml.indent(4)
|
||||||
|
@ -60,7 +64,7 @@ class PluginInstance(DBInstance):
|
||||||
config: BaseProxyConfig | None
|
config: BaseProxyConfig | None
|
||||||
base_cfg: RecursiveDict[CommentedMap] | None
|
base_cfg: RecursiveDict[CommentedMap] | None
|
||||||
base_cfg_str: str | None
|
base_cfg_str: str | None
|
||||||
inst_db: sql.engine.Engine | None
|
inst_db: sql.engine.Engine | Database | None
|
||||||
inst_db_tables: dict[str, sql.Table] | None
|
inst_db_tables: dict[str, sql.Table] | None
|
||||||
inst_webapp: PluginWebApp | None
|
inst_webapp: PluginWebApp | None
|
||||||
inst_webapp_url: str | None
|
inst_webapp_url: str | None
|
||||||
|
@ -130,8 +134,6 @@ class PluginInstance(DBInstance):
|
||||||
self.log.error(f"Failed to get client for user {self.primary_user}")
|
self.log.error(f"Failed to get client for user {self.primary_user}")
|
||||||
await self.update_enabled(False)
|
await self.update_enabled(False)
|
||||||
return False
|
return False
|
||||||
if self.loader.meta.database:
|
|
||||||
self.enable_database()
|
|
||||||
if self.loader.meta.webapp:
|
if self.loader.meta.webapp:
|
||||||
self.enable_webapp()
|
self.enable_webapp()
|
||||||
self.log.debug("Plugin instance dependencies loaded")
|
self.log.debug("Plugin instance dependencies loaded")
|
||||||
|
@ -147,9 +149,9 @@ class PluginInstance(DBInstance):
|
||||||
self.inst_webapp = None
|
self.inst_webapp = None
|
||||||
self.inst_webapp_url = None
|
self.inst_webapp_url = None
|
||||||
|
|
||||||
def enable_database(self) -> None:
|
@property
|
||||||
db_path = os.path.join(self.maubot.config["plugin_directories.db"], self.id)
|
def _sqlite_db_path(self) -> str:
|
||||||
self.inst_db = sql.create_engine(f"sqlite:///{db_path}.db")
|
return os.path.join(self.maubot.config["plugin_databases.sqlite"], f"{self.id}.db")
|
||||||
|
|
||||||
async def delete(self) -> None:
|
async def delete(self) -> None:
|
||||||
if self.loader is not None:
|
if self.loader is not None:
|
||||||
|
@ -162,11 +164,8 @@ class PluginInstance(DBInstance):
|
||||||
pass
|
pass
|
||||||
await super().delete()
|
await super().delete()
|
||||||
if self.inst_db:
|
if self.inst_db:
|
||||||
self.inst_db.dispose()
|
await self.stop_database()
|
||||||
ZippedPluginLoader.trash(
|
await self.delete_database()
|
||||||
os.path.join(self.maubot.config["plugin_directories.db"], f"{self.id}.db"),
|
|
||||||
reason="deleted",
|
|
||||||
)
|
|
||||||
if self.inst_webapp:
|
if self.inst_webapp:
|
||||||
self.disable_webapp()
|
self.disable_webapp()
|
||||||
|
|
||||||
|
@ -178,6 +177,56 @@ class PluginInstance(DBInstance):
|
||||||
yaml.dump(data, buf)
|
yaml.dump(data, buf)
|
||||||
self.config_str = buf.getvalue()
|
self.config_str = buf.getvalue()
|
||||||
|
|
||||||
|
async def start_database(
|
||||||
|
self, upgrade_table: UpgradeTable | None = None, actually_start: bool = True
|
||||||
|
) -> None:
|
||||||
|
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
|
||||||
|
self.inst_db = sql.create_engine(f"sqlite:///{self._sqlite_db_path}")
|
||||||
|
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
|
||||||
|
instance_db_log = db_log.getChild(self.id)
|
||||||
|
# TODO should there be a way to choose between SQLite and Postgres
|
||||||
|
# for individual instances? Maybe checking the existence of the SQLite file.
|
||||||
|
if self.maubot.plugin_postgres_db:
|
||||||
|
self.inst_db = ProxyPostgresDatabase(
|
||||||
|
pool=self.maubot.plugin_postgres_db,
|
||||||
|
instance_id=self.id,
|
||||||
|
max_conns=self.maubot.config["plugin_databases.postgres_max_conns_per_plugin"],
|
||||||
|
upgrade_table=upgrade_table,
|
||||||
|
log=instance_db_log,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.inst_db = Database.create(
|
||||||
|
f"sqlite:///{self._sqlite_db_path}",
|
||||||
|
upgrade_table=upgrade_table,
|
||||||
|
log=instance_db_log,
|
||||||
|
)
|
||||||
|
if actually_start:
|
||||||
|
await self.inst_db.start()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
|
||||||
|
|
||||||
|
async def stop_database(self) -> None:
|
||||||
|
if isinstance(self.inst_db, Database):
|
||||||
|
await self.inst_db.stop()
|
||||||
|
elif isinstance(self.inst_db, sql.engine.Engine):
|
||||||
|
self.inst_db.dispose()
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unknown database type {type(self.inst_db).__name__}")
|
||||||
|
|
||||||
|
async def delete_database(self) -> None:
|
||||||
|
if self.loader.meta.database_type == DatabaseType.SQLALCHEMY:
|
||||||
|
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
|
||||||
|
elif self.loader.meta.database_type == DatabaseType.ASYNCPG:
|
||||||
|
if self.inst_db is None:
|
||||||
|
await self.start_database(None, actually_start=False)
|
||||||
|
if isinstance(self.inst_db, ProxyPostgresDatabase):
|
||||||
|
await self.inst_db.delete()
|
||||||
|
else:
|
||||||
|
ZippedPluginLoader.trash(self._sqlite_db_path, reason="deleted")
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unrecognized database type {self.loader.meta.database_type}")
|
||||||
|
self.inst_db = None
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self.started:
|
if self.started:
|
||||||
self.log.warning("Ignoring start() call to already started plugin")
|
self.log.warning("Ignoring start() call to already started plugin")
|
||||||
|
@ -196,9 +245,8 @@ class PluginInstance(DBInstance):
|
||||||
elif not self.loader.meta.webapp and self.inst_webapp is not None:
|
elif not self.loader.meta.webapp and self.inst_webapp is not None:
|
||||||
self.log.debug("Disabling webapp after plugin meta reload")
|
self.log.debug("Disabling webapp after plugin meta reload")
|
||||||
self.disable_webapp()
|
self.disable_webapp()
|
||||||
if self.loader.meta.database and self.inst_db is None:
|
if self.loader.meta.database:
|
||||||
self.log.debug("Enabling database after plugin meta reload")
|
await self.start_database(cls.get_db_upgrade_table())
|
||||||
self.enable_database()
|
|
||||||
config_class = cls.get_config_class()
|
config_class = cls.get_config_class()
|
||||||
if config_class:
|
if config_class:
|
||||||
try:
|
try:
|
||||||
|
@ -254,6 +302,11 @@ class PluginInstance(DBInstance):
|
||||||
except Exception:
|
except Exception:
|
||||||
self.log.exception("Failed to stop instance")
|
self.log.exception("Failed to stop instance")
|
||||||
self.plugin = None
|
self.plugin = None
|
||||||
|
if self.inst_db:
|
||||||
|
try:
|
||||||
|
await self.stop_database()
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Failed to stop instance database")
|
||||||
self.inst_db_tables = None
|
self.inst_db_tables = None
|
||||||
|
|
||||||
async def update_id(self, new_id: str | None) -> None:
|
async def update_id(self, new_id: str | None) -> None:
|
||||||
|
|
94
maubot/lib/plugin_db.py
Normal file
94
maubot/lib/plugin_db.py
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
# maubot - A plugin-based Matrix bot system.
|
||||||
|
# Copyright (C) 2022 Tulir Asokan
|
||||||
|
#
|
||||||
|
# This program is free software: you can redistribute it and/or modify
|
||||||
|
# it under the terms of the GNU Affero General Public License as published by
|
||||||
|
# the Free Software Foundation, either version 3 of the License, or
|
||||||
|
# (at your option) any later version.
|
||||||
|
#
|
||||||
|
# This program is distributed in the hope that it will be useful,
|
||||||
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
# GNU Affero General Public License for more details.
|
||||||
|
#
|
||||||
|
# You should have received a copy of the GNU Affero General Public License
|
||||||
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from mautrix.util.async_db import Database, PostgresDatabase, Scheme, UpgradeTable
|
||||||
|
from mautrix.util.async_db.connection import LoggingConnection
|
||||||
|
from mautrix.util.logging import TraceLogger
|
||||||
|
|
||||||
|
remove_double_quotes = str.maketrans({'"': "_"})
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyPostgresDatabase(Database):
|
||||||
|
scheme = Scheme.POSTGRES
|
||||||
|
_underlying_pool: PostgresDatabase
|
||||||
|
_schema: str
|
||||||
|
_default_search_path: str
|
||||||
|
_conn_sema: asyncio.Semaphore
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pool: PostgresDatabase,
|
||||||
|
instance_id: str,
|
||||||
|
max_conns: int,
|
||||||
|
upgrade_table: UpgradeTable | None,
|
||||||
|
log: TraceLogger | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(pool.url, upgrade_table=upgrade_table, log=log)
|
||||||
|
self._underlying_pool = pool
|
||||||
|
# Simple accidental SQL injection prevention.
|
||||||
|
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
|
||||||
|
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
|
||||||
|
self._default_search_path = '"$user", public'
|
||||||
|
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
async with self._underlying_pool.acquire() as conn:
|
||||||
|
self._default_search_path = await conn.fetchval("SHOW search_path")
|
||||||
|
self.log.debug(f"Found default search path: {self._default_search_path}")
|
||||||
|
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
|
||||||
|
await super().start()
|
||||||
|
|
||||||
|
async def stop(self) -> None:
|
||||||
|
while not self._conn_sema.locked():
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._conn_sema.acquire(), timeout=3)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self.log.warning(
|
||||||
|
"Failed to drain plugin database connection pool, "
|
||||||
|
"the plugin may be leaking database connections"
|
||||||
|
)
|
||||||
|
break
|
||||||
|
|
||||||
|
async def delete(self) -> None:
|
||||||
|
self.log.debug(f"Deleting schema {self._schema} and all data in it")
|
||||||
|
try:
|
||||||
|
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
|
||||||
|
except Exception:
|
||||||
|
self.log.warning("Failed to delete schema", exc_info=True)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def acquire(self) -> LoggingConnection:
|
||||||
|
conn: LoggingConnection
|
||||||
|
async with self._conn_sema, self._underlying_pool.acquire() as conn:
|
||||||
|
await conn.execute(f"SET search_path = {self._default_search_path}")
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
finally:
|
||||||
|
if not conn.wrapped.is_closed():
|
||||||
|
try:
|
||||||
|
await conn.execute(f"SET search_path = {self._default_search_path}")
|
||||||
|
except Exception:
|
||||||
|
self.log.exception("Error resetting search_path after use")
|
||||||
|
await conn.wrapped.close()
|
||||||
|
else:
|
||||||
|
self.log.debug("Connection was closed after use, not resetting search_path")
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ProxyPostgresDatabase"]
|
|
@ -1,2 +1,3 @@
|
||||||
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader, PluginMeta
|
from .abc import BasePluginLoader, IDConflictError, PluginClass, PluginLoader
|
||||||
|
from .meta import DatabaseType, PluginMeta
|
||||||
from .zip import MaubotZipImportError, ZippedPluginLoader
|
from .zip import MaubotZipImportError, ZippedPluginLoader
|
||||||
|
|
|
@ -18,7 +18,13 @@ from typing import List
|
||||||
from attr import dataclass
|
from attr import dataclass
|
||||||
from packaging.version import InvalidVersion, Version
|
from packaging.version import InvalidVersion, Version
|
||||||
|
|
||||||
from mautrix.types import SerializableAttrs, SerializerError, deserializer, serializer
|
from mautrix.types import (
|
||||||
|
ExtensibleEnum,
|
||||||
|
SerializableAttrs,
|
||||||
|
SerializerError,
|
||||||
|
deserializer,
|
||||||
|
serializer,
|
||||||
|
)
|
||||||
|
|
||||||
from ..__meta__ import __version__
|
from ..__meta__ import __version__
|
||||||
|
|
||||||
|
@ -36,6 +42,11 @@ def deserialize_version(version: str) -> Version:
|
||||||
raise SerializerError("Invalid version") from e
|
raise SerializerError("Invalid version") from e
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseType(ExtensibleEnum):
|
||||||
|
SQLALCHEMY = "sqlalchemy"
|
||||||
|
ASYNCPG = "asyncpg"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PluginMeta(SerializableAttrs):
|
class PluginMeta(SerializableAttrs):
|
||||||
id: str
|
id: str
|
||||||
|
@ -45,6 +56,7 @@ class PluginMeta(SerializableAttrs):
|
||||||
|
|
||||||
maubot: Version = Version(__version__)
|
maubot: Version = Version(__version__)
|
||||||
database: bool = False
|
database: bool = False
|
||||||
|
database_type: DatabaseType = DatabaseType.SQLALCHEMY
|
||||||
config: bool = False
|
config: bool = False
|
||||||
webapp: bool = False
|
webapp: bool = False
|
||||||
license: str = ""
|
license: str = ""
|
||||||
|
|
|
@ -55,6 +55,7 @@ async def _create_instance(instance_id: str, data: dict) -> web.Response:
|
||||||
instance.enabled = data.get("enabled", True)
|
instance.enabled = data.get("enabled", True)
|
||||||
instance.config_str = data.get("config") or ""
|
instance.config_str = data.get("config") or ""
|
||||||
await instance.update()
|
await instance.update()
|
||||||
|
await instance.load()
|
||||||
await instance.start()
|
await instance.start()
|
||||||
return resp.created(instance.to_dict())
|
return resp.created(instance.to_dict())
|
||||||
|
|
||||||
|
|
|
@ -23,10 +23,11 @@ from aiohttp import ClientSession
|
||||||
from sqlalchemy.engine.base import Engine
|
from sqlalchemy.engine.base import Engine
|
||||||
from yarl import URL
|
from yarl import URL
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from mautrix.util.async_db import Database, UpgradeTable
|
||||||
from mautrix.util.config import BaseProxyConfig
|
from mautrix.util.config import BaseProxyConfig
|
||||||
from mautrix.util.logging import TraceLogger
|
from mautrix.util.logging import TraceLogger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
from .client import MaubotMatrixClient
|
from .client import MaubotMatrixClient
|
||||||
from .loader import BasePluginLoader
|
from .loader import BasePluginLoader
|
||||||
from .plugin_server import PluginWebApp
|
from .plugin_server import PluginWebApp
|
||||||
|
@ -40,7 +41,7 @@ class Plugin(ABC):
|
||||||
loop: AbstractEventLoop
|
loop: AbstractEventLoop
|
||||||
loader: BasePluginLoader
|
loader: BasePluginLoader
|
||||||
config: BaseProxyConfig | None
|
config: BaseProxyConfig | None
|
||||||
database: Engine | None
|
database: Engine | Database | None
|
||||||
webapp: PluginWebApp | None
|
webapp: PluginWebApp | None
|
||||||
webapp_url: URL | None
|
webapp_url: URL | None
|
||||||
|
|
||||||
|
@ -124,6 +125,10 @@ class Plugin(ABC):
|
||||||
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
def get_config_class(cls) -> type[BaseProxyConfig] | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_db_upgrade_table(cls) -> UpgradeTable | None:
|
||||||
|
return None
|
||||||
|
|
||||||
def on_external_config_update(self) -> Awaitable[None] | None:
|
def on_external_config_update(self) -> Awaitable[None] | None:
|
||||||
if self.config:
|
if self.config:
|
||||||
self.config.load_and_update()
|
self.config.load_and_update()
|
||||||
|
|
|
@ -1,9 +1,6 @@
|
||||||
# Format: #/name defines a new extras_require group called name
|
# Format: #/name defines a new extras_require group called name
|
||||||
# Uncommented lines after the group definition insert things into that group.
|
# Uncommented lines after the group definition insert things into that group.
|
||||||
|
|
||||||
#/sqlite
|
|
||||||
aiosqlite>=0.16,<0.18
|
|
||||||
|
|
||||||
#/encryption
|
#/encryption
|
||||||
python-olm>=3,<4
|
python-olm>=3,<4
|
||||||
pycryptodome>=3,<4
|
pycryptodome>=3,<4
|
||||||
|
|
|
@ -3,6 +3,7 @@ aiohttp>=3,<4
|
||||||
yarl>=1,<2
|
yarl>=1,<2
|
||||||
SQLAlchemy>=1,<1.4
|
SQLAlchemy>=1,<1.4
|
||||||
asyncpg>=0.20,<0.26
|
asyncpg>=0.20,<0.26
|
||||||
|
aiosqlite>=0.16,<0.18
|
||||||
alembic>=1,<2
|
alembic>=1,<2
|
||||||
commonmark>=0.9,<1
|
commonmark>=0.9,<1
|
||||||
ruamel.yaml>=0.15.35,<0.18
|
ruamel.yaml>=0.15.35,<0.18
|
||||||
|
|
Loading…
Reference in a new issue