Make frontend database viewer work with asyncpg
This commit is contained in:
parent
32688372fe
commit
6e14cbf5dc
5 changed files with 175 additions and 49 deletions
|
@ -28,7 +28,7 @@ from ruamel.yaml.comments import CommentedMap
|
|||
import sqlalchemy as sql
|
||||
|
||||
from mautrix.types import UserID
|
||||
from mautrix.util.async_db import Database, SQLiteDatabase, UpgradeTable
|
||||
from mautrix.util.async_db import Database, Scheme, UpgradeTable
|
||||
from mautrix.util.async_getter_lock import async_getter_lock
|
||||
from mautrix.util.config import BaseProxyConfig, RecursiveDict
|
||||
from mautrix.util.logging import TraceLogger
|
||||
|
@ -65,7 +65,7 @@ class PluginInstance(DBInstance):
|
|||
base_cfg: RecursiveDict[CommentedMap] | None
|
||||
base_cfg_str: str | None
|
||||
inst_db: sql.engine.Engine | Database | None
|
||||
inst_db_tables: dict[str, sql.Table] | None
|
||||
inst_db_tables: dict | None
|
||||
inst_webapp: PluginWebApp | None
|
||||
inst_webapp_url: str | None
|
||||
started: bool
|
||||
|
@ -113,11 +113,99 @@ class PluginInstance(DBInstance):
|
|||
),
|
||||
}
|
||||
|
||||
def get_db_tables(self) -> dict[str, sql.Table]:
|
||||
if not self.inst_db_tables:
|
||||
def _introspect_sqlalchemy(self) -> dict:
|
||||
metadata = sql.MetaData()
|
||||
metadata.reflect(self.inst_db)
|
||||
self.inst_db_tables = metadata.tables
|
||||
return {
|
||||
table.name: {
|
||||
"columns": {
|
||||
column.name: {
|
||||
"type": str(column.type),
|
||||
"unique": column.unique or False,
|
||||
"default": column.default,
|
||||
"nullable": column.nullable,
|
||||
"primary": column.primary_key,
|
||||
}
|
||||
for column in table.columns
|
||||
},
|
||||
}
|
||||
for table in metadata.tables.values()
|
||||
}
|
||||
|
||||
async def _introspect_sqlite(self) -> dict:
|
||||
q = """
|
||||
SELECT
|
||||
m.name AS table_name,
|
||||
p.cid AS col_id,
|
||||
p.name AS column_name,
|
||||
p.type AS data_type,
|
||||
p.pk AS is_primary,
|
||||
p.dflt_value AS column_default,
|
||||
p.[notnull] AS is_nullable
|
||||
FROM sqlite_master m
|
||||
LEFT JOIN pragma_table_info((m.name)) p
|
||||
WHERE m.type = 'table'
|
||||
ORDER BY table_name, col_id
|
||||
"""
|
||||
data = await self.inst_db.fetch(q)
|
||||
tables = defaultdict(lambda: {"columns": {}})
|
||||
for column in data:
|
||||
table_name = column["table_name"]
|
||||
col_name = column["column_name"]
|
||||
tables[table_name]["columns"][col_name] = {
|
||||
"type": column["data_type"],
|
||||
"nullable": bool(column["is_nullable"]),
|
||||
"default": column["column_default"],
|
||||
"primary": bool(column["is_primary"]),
|
||||
# TODO uniqueness?
|
||||
}
|
||||
return tables
|
||||
|
||||
async def _introspect_postgres(self) -> dict:
|
||||
assert isinstance(self.inst_db, ProxyPostgresDatabase)
|
||||
q = """
|
||||
SELECT col.table_name, col.column_name, col.data_type, col.is_nullable, col.column_default,
|
||||
tc.constraint_type
|
||||
FROM information_schema.columns col
|
||||
LEFT JOIN information_schema.constraint_column_usage ccu
|
||||
ON ccu.column_name=col.column_name
|
||||
LEFT JOIN information_schema.table_constraints tc
|
||||
ON col.table_name=tc.table_name
|
||||
AND col.table_schema=tc.table_schema
|
||||
AND ccu.constraint_name=tc.constraint_name
|
||||
AND ccu.constraint_schema=tc.constraint_schema
|
||||
AND tc.constraint_type IN ('PRIMARY KEY', 'UNIQUE')
|
||||
WHERE col.table_schema=$1
|
||||
"""
|
||||
data = await self.inst_db.fetch(q, self.inst_db.schema_name)
|
||||
tables = defaultdict(lambda: {"columns": {}})
|
||||
for column in data:
|
||||
table_name = column["table_name"]
|
||||
col_name = column["column_name"]
|
||||
tables[table_name]["columns"].setdefault(
|
||||
col_name,
|
||||
{
|
||||
"type": column["data_type"],
|
||||
"nullable": column["is_nullable"],
|
||||
"default": column["column_default"],
|
||||
"primary": False,
|
||||
"unique": False,
|
||||
},
|
||||
)
|
||||
if column["constraint_type"] == "PRIMARY KEY":
|
||||
tables[table_name]["columns"][col_name]["primary"] = True
|
||||
elif column["constraint_type"] == "UNIQUE":
|
||||
tables[table_name]["columns"][col_name]["unique"] = True
|
||||
return tables
|
||||
|
||||
async def get_db_tables(self) -> dict:
|
||||
if self.inst_db_tables is None:
|
||||
if isinstance(self.inst_db, sql.engine.Engine):
|
||||
self.inst_db_tables = self._introspect_sqlalchemy()
|
||||
elif self.inst_db.scheme == Scheme.SQLITE:
|
||||
self.inst_db_tables = await self._introspect_sqlite()
|
||||
else:
|
||||
self.inst_db_tables = await self._introspect_postgres()
|
||||
return self.inst_db_tables
|
||||
|
||||
async def load(self) -> bool:
|
||||
|
|
|
@ -28,7 +28,8 @@ remove_double_quotes = str.maketrans({'"': "_"})
|
|||
class ProxyPostgresDatabase(Database):
|
||||
scheme = Scheme.POSTGRES
|
||||
_underlying_pool: PostgresDatabase
|
||||
_schema: str
|
||||
schema_name: str
|
||||
_quoted_schema: str
|
||||
_default_search_path: str
|
||||
_conn_sema: asyncio.Semaphore
|
||||
|
||||
|
@ -44,7 +45,8 @@ class ProxyPostgresDatabase(Database):
|
|||
self._underlying_pool = pool
|
||||
# Simple accidental SQL injection prevention.
|
||||
# Doesn't have to be perfect, since plugin instance IDs can only be set by admins anyway.
|
||||
self._schema = f'"mbp_{instance_id.translate(remove_double_quotes)}"'
|
||||
self.schema_name = f"mbp_{instance_id.translate(remove_double_quotes)}"
|
||||
self._quoted_schema = f'"{self.schema_name}"'
|
||||
self._default_search_path = '"$user", public'
|
||||
self._conn_sema = asyncio.BoundedSemaphore(max_conns)
|
||||
|
||||
|
@ -52,7 +54,7 @@ class ProxyPostgresDatabase(Database):
|
|||
async with self._underlying_pool.acquire() as conn:
|
||||
self._default_search_path = await conn.fetchval("SHOW search_path")
|
||||
self.log.debug(f"Found default search path: {self._default_search_path}")
|
||||
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._schema}")
|
||||
await conn.execute(f"CREATE SCHEMA IF NOT EXISTS {self._quoted_schema}")
|
||||
await super().start()
|
||||
|
||||
async def stop(self) -> None:
|
||||
|
@ -67,9 +69,11 @@ class ProxyPostgresDatabase(Database):
|
|||
break
|
||||
|
||||
async def delete(self) -> None:
|
||||
self.log.debug(f"Deleting schema {self._schema} and all data in it")
|
||||
self.log.debug(f"Deleting schema {self._quoted_schema} and all data in it")
|
||||
try:
|
||||
await self._underlying_pool.execute(f"DROP SCHEMA IF EXISTS {self._schema} CASCADE")
|
||||
await self._underlying_pool.execute(
|
||||
f"DROP SCHEMA IF EXISTS {self._quoted_schema} CASCADE"
|
||||
)
|
||||
except Exception:
|
||||
self.log.warning("Failed to delete schema", exc_info=True)
|
||||
|
||||
|
@ -77,7 +81,7 @@ class ProxyPostgresDatabase(Database):
|
|||
async def acquire(self) -> LoggingConnection:
|
||||
conn: LoggingConnection
|
||||
async with self._conn_sema, self._underlying_pool.acquire() as conn:
|
||||
await conn.execute(f"SET search_path = {self._default_search_path}")
|
||||
await conn.execute(f"SET search_path = {self._quoted_schema}")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
|
|
|
@ -18,9 +18,11 @@ from __future__ import annotations
|
|||
from datetime import datetime
|
||||
|
||||
from aiohttp import web
|
||||
from sqlalchemy import Column, Table, asc, desc, exc
|
||||
from sqlalchemy import asc, desc, engine, exc
|
||||
from sqlalchemy.engine.result import ResultProxy, RowProxy
|
||||
from sqlalchemy.orm import Query
|
||||
import aiosqlite
|
||||
|
||||
from mautrix.util.async_db import Database
|
||||
|
||||
from ...instance import PluginInstance
|
||||
from .base import routes
|
||||
|
@ -35,32 +37,7 @@ async def get_database(request: web.Request) -> web.Response:
|
|||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
return resp.plugin_has_no_database
|
||||
table: Table
|
||||
column: Column
|
||||
return web.json_response(
|
||||
{
|
||||
table.name: {
|
||||
"columns": {
|
||||
column.name: {
|
||||
"type": str(column.type),
|
||||
"unique": column.unique or False,
|
||||
"default": column.default,
|
||||
"nullable": column.nullable,
|
||||
"primary": column.primary_key,
|
||||
"autoincrement": column.autoincrement,
|
||||
}
|
||||
for column in table.columns
|
||||
},
|
||||
}
|
||||
for table in instance.get_db_tables().values()
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def check_type(val):
|
||||
if isinstance(val, datetime):
|
||||
return val.isoformat()
|
||||
return val
|
||||
return web.json_response(await instance.get_db_tables())
|
||||
|
||||
|
||||
@routes.get("/instance/{id}/database/{table}")
|
||||
|
@ -71,7 +48,7 @@ async def get_table(request: web.Request) -> web.Response:
|
|||
return resp.instance_not_found
|
||||
elif not instance.inst_db:
|
||||
return resp.plugin_has_no_database
|
||||
tables = instance.get_db_tables()
|
||||
tables = await instance.get_db_tables()
|
||||
try:
|
||||
table = tables[request.match_info.get("table", "")]
|
||||
except KeyError:
|
||||
|
@ -87,7 +64,8 @@ async def get_table(request: web.Request) -> web.Response:
|
|||
except KeyError:
|
||||
order = []
|
||||
limit = int(request.query.get("limit", "100"))
|
||||
return execute_query(instance, table.select().order_by(*order).limit(limit))
|
||||
if isinstance(instance.inst_db, engine.Engine):
|
||||
return _execute_query_sqlalchemy(instance, table.select().order_by(*order).limit(limit))
|
||||
|
||||
|
||||
@routes.post("/instance/{id}/database/query")
|
||||
|
@ -103,12 +81,54 @@ async def query(request: web.Request) -> web.Response:
|
|||
sql_query = data["query"]
|
||||
except KeyError:
|
||||
return resp.query_missing
|
||||
return execute_query(instance, sql_query, rows_as_dict=data.get("rows_as_dict", False))
|
||||
rows_as_dict = data.get("rows_as_dict", False)
|
||||
if isinstance(instance.inst_db, engine.Engine):
|
||||
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
|
||||
elif isinstance(instance.inst_db, Database):
|
||||
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
|
||||
else:
|
||||
return resp.unsupported_plugin_database
|
||||
|
||||
|
||||
def execute_query(
|
||||
instance: PluginInstance, sql_query: str | Query, rows_as_dict: bool = False
|
||||
def check_type(val):
|
||||
if isinstance(val, datetime):
|
||||
return val.isoformat()
|
||||
return val
|
||||
|
||||
|
||||
async def _execute_query_asyncpg(
|
||||
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
|
||||
) -> web.Response:
|
||||
data = {"ok": True, "query": sql_query}
|
||||
if sql_query.upper().startswith("SELECT"):
|
||||
res = await instance.inst_db.fetch(sql_query)
|
||||
data["rows"] = [
|
||||
(
|
||||
{key: check_type(value) for key, value in row.items()}
|
||||
if rows_as_dict
|
||||
else [check_type(value) for value in row]
|
||||
)
|
||||
for row in res
|
||||
]
|
||||
if len(res) > 0:
|
||||
# TODO can we find column names when there are no rows?
|
||||
data["columns"] = list(res[0].keys())
|
||||
else:
|
||||
res = await instance.inst_db.execute(sql_query)
|
||||
if isinstance(res, str):
|
||||
data["status_msg"] = res
|
||||
elif isinstance(res, aiosqlite.Cursor):
|
||||
data["rowcount"] = res.rowcount
|
||||
# data["inserted_primary_key"] = res.lastrowid
|
||||
else:
|
||||
data["status_msg"] = "unknown status"
|
||||
return web.json_response(data)
|
||||
|
||||
|
||||
def _execute_query_sqlalchemy(
|
||||
instance: PluginInstance, sql_query: str, rows_as_dict: bool = False
|
||||
) -> web.Response:
|
||||
assert isinstance(instance.inst_db, engine.Engine)
|
||||
try:
|
||||
res: ResultProxy = instance.inst_db.execute(sql_query)
|
||||
except exc.IntegrityError as e:
|
||||
|
|
|
@ -299,6 +299,15 @@ class _Response:
|
|||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def unsupported_plugin_database(self) -> web.Response:
|
||||
return web.json_response(
|
||||
{
|
||||
"error": "The database type is not supported by this API",
|
||||
"errcode": "unsupported_plugin_database",
|
||||
}
|
||||
)
|
||||
|
||||
@property
|
||||
def table_not_found(self) -> web.Response:
|
||||
return web.json_response(
|
||||
|
|
|
@ -44,6 +44,7 @@ class InstanceDatabase extends Component {
|
|||
error: null,
|
||||
|
||||
prevQuery: null,
|
||||
statusMsg: null,
|
||||
rowCount: null,
|
||||
insertedPrimaryKey: null,
|
||||
}
|
||||
|
@ -111,6 +112,7 @@ class InstanceDatabase extends Component {
|
|||
prevQuery: null,
|
||||
rowCount: null,
|
||||
insertedPrimaryKey: null,
|
||||
statusMsg: null,
|
||||
error: null,
|
||||
})
|
||||
}
|
||||
|
@ -127,7 +129,8 @@ class InstanceDatabase extends Component {
|
|||
this.setState({
|
||||
prevQuery: res.query,
|
||||
rowCount: res.rowcount,
|
||||
insertedPrimaryKey: res.insertedPrimaryKey,
|
||||
insertedPrimaryKey: res.inserted_primary_key,
|
||||
statusMsg: res.status_msg,
|
||||
})
|
||||
this.buildSQLQuery(this.state.selectedTable, false)
|
||||
}
|
||||
|
@ -298,8 +301,10 @@ class InstanceDatabase extends Component {
|
|||
</div>}
|
||||
{this.state.prevQuery && <div className="prev-query">
|
||||
<p>
|
||||
Executed <span className="query">{this.state.prevQuery}</span> -
|
||||
affected <strong>{this.state.rowCount} rows</strong>.
|
||||
Executed <span className="query">{this.state.prevQuery}</span> - {
|
||||
this.state.statusMsg
|
||||
|| <>affected <strong>{this.state.rowCount} rows</strong>.</>
|
||||
}
|
||||
</p>
|
||||
{this.state.insertedPrimaryKey && <p className="inserted-primary-key">
|
||||
Inserted primary key: {this.state.insertedPrimaryKey}
|
||||
|
|
Loading…
Reference in a new issue