Make frontend database viewer work with asyncpg

Tulir Asokan 2022-03-26 16:54:16 +02:00
parent 32688372fe
commit 6e14cbf5dc
5 changed files with 175 additions and 49 deletions

View file

@@ -28,7 +28,7 @@ from ruamel.yaml.comments import CommentedMap
import sqlalchemy as sql
from mautrix.types import UserID
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
from mautrix.util.async_db import Database, Scheme, UpgradeTable
from mautrix.util.async_getter_lock import async_getter_lock
from mautrix.util.config import BaseProxyConfig, RecursiveDict
from mautrix.util.logging import TraceLogger
@@ -65,7 +65,7 @@ class PluginInstance(DBInstance):
base_cfg: RecursiveDict[CommentedMap] | None
base_cfg_str: str | None
inst_db: sql.engine.Engine | Database | None
inst_db_tables: dict[str, sql.Table] | None
inst_db_tables: dict | None
inst_webapp: PluginWebApp | None
inst_webapp_url: str | None
started: bool
@@ -113,11 +113,99 @@ class PluginInstance(DBInstance):
),
}
def get_db_tables(self) -> dict[str, sql.Table]:
if not self.inst_db_tables:
def _introspect_sqlalchemy(self) -> dict:
metadata = sql.MetaData()
metadata.reflect(self.inst_db)
self.inst_db_tables = metadata.tables
return {
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
}
for column in table.columns
},
}
for table in metadata.tables.values()
}
async def _introspect_sqlite(self) -> dict:
q = """
SELECT
m.name AS table_name,
p.cid AS col_id,
p.name AS column_name,
p.type AS data_type,
p.pk AS is_primary,
p.dflt_value AS column_default,
p.[notnull] AS is_nullable
FROM sqlite_master m
LEFT JOIN pragma_table_info((m.name)) p
WHERE m.type = 'table'
ORDER BY table_name, col_id
"""
data = await self.inst_db.fetch(q)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"][col_name] = {
"type": column["data_type"],
"nullable": bool(column["is_nullable"]),
"default": column["column_default"],
"primary": bool(column["is_primary"]),
# TODO uniqueness?
}
return tables
async def _introspect_postgres(self) -> dict:
assert isinstance(self.inst_db, ProxyPostgresDatabase)
q = """
SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default,
tc.constraint_type
FROM information_schema.columns col
LEFT JOIN information_schema.constraint_column_usage ccu
ON ccu.column_name=col.column_name
LEFT JOIN information_schema.table_constraints tc
ON col.table_name=tc.table_name
AND col.table_schema=tc.table_schema
AND ccu.constraint_name=tc.constraint_name
AND ccu.constraint_schema=tc.constraint_schema
AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
WHERE col.table_schema=$1
"""
data = await self.inst_db.fetch(q, self.inst_db.schema_name)
tables = defaultdict(lambda: {"columns": {}})
for column in data:
table_name = column["table_name"]
col_name = column["column_name"]
tables[table_name]["columns"].setdefault(
col_name,
{
"type": column["data_type"],
"nullable": column["is_nullable"],
"default": column["column_default"],
"primary": False,
"unique": False,
},
)
if column["constraint_type"] == "PRIMARY KEY":
tables[table_name]["columns"][col_name]["primary"] = True
elif column["constraint_type"] == "UNIQUE":
tables[table_name]["columns"][col_name]["unique"] = True
return tables
async def get_db_tables(self) -> dict:
if self.inst_db_tables is None:
if isinstance(self.inst_db, sql.engine.Engine):
self.inst_db_tables = self._introspect_sqlalchemy()
elif self.inst_db.scheme == Scheme.SQLITE:
self.inst_db_tables = await self._introspect_sqlite()
else:
self.inst_db_tables = await self._introspect_postgres()
return self.inst_db_tables
async def load(self) -> bool:
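
Note: the three introspection paths above all normalise to the same nested dict, so the management API no longer needs SQLAlchemy objects. A minimal sketch of what get_db_tables() might return for a hypothetical plugin table (table, column, and type names are illustrative only; the SQLite path currently omits "unique", per the TODO above):

    {
        "reminder": {                  # hypothetical table name
            "columns": {
                "id": {
                    "type": "INTEGER",
                    "nullable": False,
                    "default": None,
                    "primary": True,
                    "unique": False,
                },
            },
        },
    }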

View file

@@ -28,7 +28,8 @@ remove_double_quotes = str.maketrans({'"': "_"})
class ProxyPostgresDatabase(Database):
scheme = Scheme.POSTGRES
_underlying_pool: PostgresDatabase
_schema: str
schema_name: str
_quoted_schema: str
_default_search_path: str
_conn_sema: asyncio.Semaphore
@@ -44,7 +45,8 @@ class ProxyPostgresDatabase(Database):
self._underlying_pool = pool
# Simple accidental SQL injection prevention.
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}"
self._quoted_schema = f'"{self.schema_name}"'
self._default_search_path = '"$user", public'
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
@@ -52,7 +54,7 @@ class ProxyPostgresDatabase(Database):
async with self._underlying_pool.acquire() as conn:
self._default_search_path = await conn.fetchval("SHOW search_path")
self.log.debug(f"Found default search path: {self._default_search_path}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}")
await super().start()
async def stop(self) -> None:
@@ -67,9 +69,11 @@ class ProxyPostgresDatabase(Database):
break
async def delete(self) -> None:
self.log.debug(f"Deleting schema {self._schema} and all data in it")
self.log.debug(f"Deleting schema {self._quoted_schema} and all data in it")
try:
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
await self._underlying_pool.execute(
f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE"
)
except Exception:
self.log.warning("Failed to delete schema", exc_info=True)
@@ -77,7 +81,7 @@
async def acquire(self) -> LoggingConnection:
conn: LoggingConnection
async with self._conn_sema, self._underlying_pool.acquire() as conn:
await conn.execute(f"SET search_path = {self._default_search_path}")
await conn.execute(f"SET search_path = {self._quoted_schema}")
try:
yield conn
finally:
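
Note: the schema name is now kept both raw (schema_name) and quoted (_quoted_schema) because the two forms are used in different positions: identifiers spliced into DDL still need the quotes, while the introspection query in instance.py passes the raw name as an ordinary query parameter. A rough sketch of the distinction, assuming a plain asyncpg connection and a hypothetical DSN and instance ID:

    import asyncpg  # illustration only; maubot goes through its connection proxy instead

    async def demo() -> None:
        conn = await asyncpg.connect("postgresql://localhost/maubot")  # hypothetical DSN
        schema_name = "mbp_myplugin"        # raw form, usable as a parameter
        quoted_schema = f'"{schema_name}"'  # quoted form, for identifier positions
        # Identifier position: cannot be parameterised, so the quoted form is interpolated.
        await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {quoted_schema}")
        # Value position: the raw name is passed as an ordinary $1 parameter.
        rows = await conn.fetch(
            "SELECT table_name FROM information_schema.tables WHERE table_schema = $1",
            schema_name,
        )
        print([r["table_name"] for r in rows])
        await conn.close()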

View file

@@ -18,9 +18,11 @@ from __future__ import annotations
from datetime import datetime
from aiohttp import web
from sqlalchemy import Column, Table, asc, desc, exc
from sqlalchemy import asc, desc, engine, exc
from sqlalchemy.engine.result import ResultProxy, RowProxy
from sqlalchemy.orm import Query
import aiosqlite
from mautrix.util.async_db import Database
from ...instance import PluginInstance
from .base import routes
@@ -35,32 +37,7 @@ async def get_database(request: web.Request) -> web.Response:
return resp.instance_not_found
elif not instance.inst_db:
return resp.plugin_has_no_database
table: Table
column: Column
return web.json_response(
{
table.name: {
"columns": {
column.name: {
"type": str(column.type),
"unique": column.unique or False,
"default": column.default,
"nullable": column.nullable,
"primary": column.primary_key,
"autoincrement": column.autoincrement,
}
for column in table.columns
},
}
for table in instance.get_db_tables().values()
}
)
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
return web.json_response(await instance.get_db_tables())
@routes.get("/instance/{id}/database/{table}")
@@ -71,7 +48,7 @@ async def get_table(request: web.Request) -> web.Response:
return resp.instance_not_found
elif not instance.inst_db:
return resp.plugin_has_no_database
tables = instance.get_db_tables()
tables = await instance.get_db_tables()
try:
table = tables[request.match_info.get("table", "")]
except KeyError:
@@ -87,7 +64,8 @@ async def get_table(request: web.Request) -> web.Response:
except KeyError:
order = []
limit = int(request.query.get("limit", "100"))
return execute_query(instance, table.select().order_by(*order).limit(limit))
if isinstance(instance.inst_db, engine.Engine):
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
@routes.post("/instance/{id}/database/query")
@@ -103,12 +81,54 @@ async def query(request: web.Request) -> web.Response:
sql_query = data["query"]
except KeyError:
return resp.query_missing
return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False))
rows_as_dict = data.get("rows_as_dict", False)
if isinstance(instance.inst_db, engine.Engine):
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
elif isinstance(instance.inst_db, Database):
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
else:
return resp.unsupported_plugin_database
def execute_query(
instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False
def check_type(val):
if isinstance(val, datetime):
return val.isoformat()
return val
async def _execute_query_asyncpg(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
data = {"ok": True, "query": sql_query}
if sql_query.upper().startswith("SELECT"):
res = await instance.inst_db.fetch(sql_query)
data["rows"] = [
(
{key: check_type(value) for key, value in row.items()}
if rows_as_dict
else [check_type(value) for value in row]
)
for row in res
]
if len(res) > 0:
# TODO can we find column names when there are no rows?
data["columns"] = list(res[0].keys())
else:
res = await instance.inst_db.execute(sql_query)
if isinstance(res, str):
data["status_msg"] = res
elif isinstance(res, aiosqlite.Cursor):
data["rowcount"] = res.rowcount
# data["inserted_primary_key"] = res.lastrowid
else:
data["status_msg"] = "unknown status"
return web.json_response(data)
def _execute_query_sqlalchemy(
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
) -> web.Response:
assert isinstance(instance.inst_db, engine.Engine)
try:
res: ResultProxy = instance.inst_db.execute(sql_query)
except exc.IntegrityError as e:
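
Note: for non-SELECT statements the asyncpg path reports whatever Database.execute() returns: on Postgres that is asyncpg's command-tag string, while on SQLite it is an aiosqlite.Cursor, from which only rowcount is taken. An illustrative JSON response body for POST /instance/{id}/database/query against a Postgres-backed plugin (query and values made up):

    {
        "ok": true,
        "query": "UPDATE reminder SET done = true WHERE id = 1",
        "status_msg": "UPDATE 1"
    }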

View file

@@ -299,6 +299,15 @@ class _Response:
}
)
@property
def unsupported_plugin_database(self) -> web.Response:
return web.json_response(
{
"error": "The database type is not supported by this API",
"errcode": "unsupported_plugin_database",
}
)
@property
def table_not_found(self) -> web.Response:
return web.json_response(

View file

@@ -44,6 +44,7 @@ class InstanceDatabase extends Component {
error: null,
prevQuery: null,
statusMsg: null,
rowCount: null,
insertedPrimaryKey: null,
}
@@ -111,6 +112,7 @@ class InstanceDatabase extends Component {
prevQuery: null,
rowCount: null,
insertedPrimaryKey: null,
statusMsg: null,
error: null,
})
}
@@ -127,7 +129,8 @@ class InstanceDatabase extends Component {
this.setState({
prevQuery: res.query,
rowCount: res.rowcount,
insertedPrimaryKey: res.insertedPrimaryKey,
insertedPrimaryKey: res.inserted_primary_key,
statusMsg: res.status_msg,
})
this.buildSQLQuery(this.state.selectedTable, false)
}
@@ -298,8 +301,10 @@ class InstanceDatabase extends Component {
</div>}
{this.state.prevQuery && <div className="prev-query">
<p>
Executed <span className="query">{this.state.prevQuery}</span> -
affected <strong>{this.state.rowCount} rows</strong>.
Executed <span className="query">{this.state.prevQuery}</span> - {
this.state.statusMsg
|| <>affected <strong>{this.state.rowCount} rows</strong>.</>
}
</p>
{this.state.insertedPrimaryKey && <p className="inserted-primary-key">
Inserted primary key: {this.state.insertedPrimaryKey}