Make frontend database viewer work with asyncpg

This commit is contained in:
Tulir Asokan 2022-03-26 16:54:16 +02:00
parent 32688372fe
commit 6e14cbf5dc
5 changed files with 175 additions and 49 deletions

View file

@ -28,7 +28,7 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql import sqlalchemy as sql
from mautrix.types import UserID from mautrix.types import UserID
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable from mautrix.util.async_db import Database, Scheme, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger from mautrix.util.logging import TraceLogger
@ -65,7 +65,7 @@ class PluginInstance(DBInstance):
base_cfg: RecursiveDict[CommentedMap] | None base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None base_cfg_str: str | None
inst_db: sql.engine.Engine | Database | None inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict[str, sql.Table] | None inst_db_tables: dict | None
inst_webapp: PluginWebApp | None inst_webapp: PluginWebApp | None
inst_webapp_url: str | None inst_webapp_url: str | None
started: bool started: bool
@ -113,11 +113,99 @@ class PluginInstance(DBInstance):
), ),
} }
def get_db_tables(self) -> dict[str, sql.Table]: def _introspect_sqlalchemy(self) -> dict:
if not self.inst_db_tables: metadata = sql.MetaData()
metadata = sql.MetaData() metadata.reflect(self.inst_db)
metadata.reflect(self.inst_db) return {
self.inst_db_tables = metadata.tables table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
}
for column in table.columns
},
}
for table in metadata.tables.values()
}
async def _introspect_sqlite(self) -> dict:
q = """
SELECT
m.name AS table_name,
p.cid AS col_id,
p.name AS column_name,
p.type AS data_type,
p.pk AS is_primary,
p.dflt_value AS column_default,
p.[notnull] AS is_nullable
FROM sqlite_master m
LEFT JOIN pragma_table_info((m.name)) p
WHERE m.type = 'table'
ORDER BY table_name, col_id
"""
data = await self.inst_db.fetch(q)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"][col_name] = {
"type": column["data_type"],
"nullable": bool(column["is_nullable"]),
"default": column["column_default"],
"primary": bool(column["is_primary"]),
# TODO uniqueness?
}
return tables
async def _introspect_postgres(self) -> dict:
assert isinstance(self.inst_db, ProxyPostgresDatabase)
q = """
SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default,
tc.constraint_type
FROM information_schema.columns col
LEFT JOIN information_schema.constraint_column_usage ccu
ON ccu.column_name=col.column_name
LEFT JOIN information_schema.table_constraints tc
ON col.table_name=tc.table_name
AND col.table_schema=tc.table_schema
AND ccu.constraint_name=tc.constraint_name
AND ccu.constraint_schema=tc.constraint_schema
AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
WHERE col.table_schema=$1
"""
data = await self.inst_db.fetch(q, self.inst_db.schema_name)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"].setdefault(
col_name,
{
"type": column["data_type"],
"nullable": column["is_nullable"],
"default": column["column_default"],
"primary": False,
"unique": False,
},
)
if column["constraint_type"] == "PRIMARY KEY":
tables[table_name]["columns"][col_name]["primary"] = True
elif column["constraint_type"] == "UNIQUE":
tables[table_name]["columns"][col_name]["unique"] = True
return tables
async def get_db_tables(self) -> dict:
if self.inst_db_tables is None:
if isinstance(self.inst_db, sql.engine.Engine):
self.inst_db_tables = self._introspect_sqlalchemy()
elif self.inst_db.scheme == Scheme.SQLITE:
self.inst_db_tables = await self._introspect_sqlite()
else:
self.inst_db_tables = await self._introspect_postgres()
return self.inst_db_tables return self.inst_db_tables
async def load(self) -> bool: async def load(self) -> bool:

View file

@ -28,7 +28,8 @@ remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database): class ProxyPostgresDatabase(Database):
scheme = Scheme.POSTGRES scheme = Scheme.POSTGRES
_underlying_pool: PostgresDatabase _underlying_pool: PostgresDatabase
_schema: str schema_name: str
_quoted_schema: str
_default_search_path: str _default_search_path: str
_conn_sema: asyncio.Semaphore _conn_sema: asyncio.Semaphore
@ -44,7 +45,8 @@ class ProxyPostgresDatabase(Database):
self._underlying_pool = pool self._underlying_pool = pool
# Simple accidental SQL injection prevention. # Simple accidental SQL injection prevention.
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway. # Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"' self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}"
self._quoted_schema = f'"{self.schema_name}"'
self._default_search_path = '"$user", public' self._default_search_path = '"$user", public'
self._conn_sema = asyncio.BoundedSemaphore(max_conns) self._conn_sema = asyncio.BoundedSemaphore(max_conns)
@ -52,7 +54,7 @@ class ProxyPostgresDatabase(Database):
async with self._underlying_pool.acquire() as conn: async with self._underlying_pool.acquire() as conn:
self._default_search_path = await conn.fetchval("SHOW search_path") self._default_search_path = await conn.fetchval("SHOW search_path")
self.log.debug(f"Found default search path: {self._default_search_path}") self.log.debug(f"Found default search path: {self._default_search_path}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}") await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}")
await super().start() await super().start()
async def stop(self) -> None: async def stop(self) -> None:
@ -67,9 +69,11 @@ class ProxyPostgresDatabase(Database):
break break
async def delete(self) -> None: async def delete(self) -> None:
self.log.debug(f"Deleting schema {self._schema} and all data in it") self.log.debug(f"Deleting schema {self._quoted_schema} and all data in it")
try: try:
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE") await self._underlying_pool.execute(
f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE"
)
except Exception: except Exception:
self.log.warning("Failed to delete schema", exc_info=True) self.log.warning("Failed to delete schema", exc_info=True)
@ -77,7 +81,7 @@ class ProxyPostgresDatabase(Database):
async def acquire(self) -> LoggingConnection: async def acquire(self) -> LoggingConnection:
conn: LoggingConnection conn: LoggingConnection
async with self._conn_sema, self._underlying_pool.acquire() as conn: async with self._conn_sema, self._underlying_pool.acquire() as conn:
await conn.execute(f"SET search_path = {self._default_search_path}") await conn.execute(f"SET search_path = {self._quoted_schema}")
try: try:
yield conn yield conn
finally: finally:

View file

@ -18,9 +18,11 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from aiohttp import web from aiohttp import web
from sqlalchemy import Column, Table, asc, desc, exc from sqlalchemy import asc, desc, engine, exc
from sqlalchemy.engine.result import ResultProxy, RowProxy from sqlalchemy.engine.result import ResultProxy, RowProxy
from sqlalchemy.orm import Query import aiosqlite
from mautrix.util.async_db import Database
from ...instance import PluginInstance from ...instance import PluginInstance
from .base import routes from .base import routes
@ -35,32 +37,7 @@ async def get_database(request: web.Request) -> web.Response:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
return resp.plugin_has_no_database return resp.plugin_has_no_database
table: Table return web.json_response(await instance.get_db_tables())
column: Column
return web.json_response(
{
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
}
for column in table.columns
},
}
for table in instance.get_db_tables().values()
}
)
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
@routes.get("/instance/{id}/database/{table}") @routes.get("/instance/{id}/database/{table}")
@ -71,7 +48,7 @@ async def get_table(request: web.Request) -> web.Response:
return resp.instance_not_found return resp.instance_not_found
elif not instance.inst_db: elif not instance.inst_db:
return resp.plugin_has_no_database return resp.plugin_has_no_database
tables = instance.get_db_tables() tables = await instance.get_db_tables()
try: try:
table = tables[request.match_info.get("table", "")] table = tables[request.match_info.get("table", "")]
except KeyError: except KeyError:
@ -87,7 +64,8 @@ async def get_table(request: web.Request) -> web.Response:
except KeyError: except KeyError:
order = [] order = []
limit = int(request.query.get("limit", "100")) limit = int(request.query.get("limit", "100"))
return execute_query(instance, table.select().order_by(*order).limit(limit)) if isinstance(instance.inst_db, engine.Engine):
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query") @routes.post("/instance/{id}/database/query")
@ -103,12 +81,54 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"] sql_query = data["query"]
except KeyError: except KeyError:
return resp.query_missing return resp.query_missing
return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False)) rows_as_dict = data.get("rows_as_dict", False)
if isinstance(instance.inst_db, engine.Engine):
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
elif isinstance(instance.inst_db, Database):
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
else:
return resp.unsupported_plugin_database
def execute_query( def check_type(val):
instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False if isinstance(val, datetime):
return val.isoformat()
return val
async def _execute_query_asyncpg(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response: ) -> web.Response:
data = {"ok": True, "query": sql_query}
if sql_query.upper().startswith("SELECT"):
res = await instance.inst_db.fetch(sql_query)
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
if len(res) > 0:
# TODO can we find column names when there are no rows?
data["columns"] = list(res[0].keys())
else:
res = await instance.inst_db.execute(sql_query)
if isinstance(res, str):
data["status_msg"] = res
elif isinstance(res, aiosqlite.Cursor):
data["rowcount"] = res.rowcount
# data["inserted_primary_key"] = res.lastrowid
else:
data["status_msg"] = "unknown status"
return web.json_response(data)
def _execute_query_sqlalchemy(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
assert isinstance(instance.inst_db, engine.Engine)
try: try:
res: ResultProxy = instance.inst_db.execute(sql_query) res: ResultProxy = instance.inst_db.execute(sql_query)
except exc.IntegrityError as e: except exc.IntegrityError as e:

View file

@ -299,6 +299,15 @@ class _Response:
} }
) )
@property
def unsupported_plugin_database(self) -> web.Response:
return web.json_response(
{
"error": "The database type is not supported by this API",
"errcode": "unsupported_plugin_database",
}
)
@property @property
def table_not_found(self) -> web.Response: def table_not_found(self) -> web.Response:
return web.json_response( return web.json_response(

View file

@ -44,6 +44,7 @@ class InstanceDatabase extends Component {
error: null, error: null,
prevQuery: null, prevQuery: null,
statusMsg: null,
rowCount: null, rowCount: null,
insertedPrimaryKey: null, insertedPrimaryKey: null,
} }
@ -111,6 +112,7 @@ class InstanceDatabase extends Component {
prevQuery: null, prevQuery: null,
rowCount: null, rowCount: null,
insertedPrimaryKey: null, insertedPrimaryKey: null,
statusMsg: null,
error: null, error: null,
}) })
} }
@ -127,7 +129,8 @@ class InstanceDatabase extends Component {
this.setState({ this.setState({
prevQuery: res.query, prevQuery: res.query,
rowCount: res.rowcount, rowCount: res.rowcount,
insertedPrimaryKey: res.insertedPrimaryKey, insertedPrimaryKey: res.inserted_primary_key,
statusMsg: res.status_msg,
}) })
this.buildSQLQuery(this.state.selectedTable, false) this.buildSQLQuery(this.state.selectedTable, false)
} }
@ -298,8 +301,10 @@ class InstanceDatabase extends Component {
</div>} </div>}
{this.state.prevQuery && <div className="prev-query"> {this.state.prevQuery && <div className="prev-query">
<p> <p>
Executed <span className="query">{this.state.prevQuery}</span> - Executed <span className="query">{this.state.prevQuery}</span> - {
affected <strong>{this.state.rowCount} rows</strong>. this.state.statusMsg
|| <>affected <strong>{this.state.rowCount} rows</strong>.</>
}
</p> </p>
{this.state.insertedPrimaryKey && <p className="inserted-primary-key"> {this.state.insertedPrimaryKey && <p className="inserted-primary-key">
Inserted primary key: {this.state.insertedPrimaryKey} Inserted primary key: {this.state.insertedPrimaryKey}