Catch errors in frontend SQL queries with asyncpg

This commit is contained in:
Tulir Asokan 2022-03-26 17:14:51 +02:00
parent 93352b6e62
commit acf24a6fd5
2 changed files with 21 additions and 1 deletions

View file

@ -18,6 +18,7 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from aiohttp import web from aiohttp import web
from asyncpg import PostgresError
from sqlalchemy import asc, desc, engine, exc from sqlalchemy import asc, desc, engine, exc
from sqlalchemy.engine.result import ResultProxy, RowProxy from sqlalchemy.engine.result import ResultProxy, RowProxy
import aiosqlite import aiosqlite
@ -85,7 +86,10 @@ async def query(request: web.Request) -> web.Response:
if isinstance(instance.inst_db, engine.Engine): if isinstance(instance.inst_db, engine.Engine):
return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict) return _execute_query_sqlalchemy(instance, sql_query, rows_as_dict)
elif isinstance(instance.inst_db, Database): elif isinstance(instance.inst_db, Database):
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict) try:
return await _execute_query_asyncpg(instance, sql_query, rows_as_dict)
except (PostgresError, aiosqlite.Error) as e:
return resp.sql_error(e, sql_query)
else: else:
return resp.unsupported_plugin_database return resp.unsupported_plugin_database

View file

@ -13,10 +13,14 @@
# #
# You should have received a copy of the GNU Affero General Public License # You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations
from http import HTTPStatus from http import HTTPStatus
from aiohttp import web from aiohttp import web
from asyncpg import PostgresError
from sqlalchemy.exc import IntegrityError, OperationalError from sqlalchemy.exc import IntegrityError, OperationalError
import aiosqlite
class _Response: class _Response:
@ -134,6 +138,18 @@ class _Response:
status=HTTPStatus.BAD_REQUEST, status=HTTPStatus.BAD_REQUEST,
) )
@staticmethod
def sql_error(error: PostgresError | aiosqlite.Error, query: str) -> web.Response:
return web.json_response(
{
"ok": False,
"query": query,
"error": str(error),
"errcode": "sql_error",
},
status=HTTPStatus.BAD_REQUEST,
)
@staticmethod @staticmethod
def sql_operational_error(error: OperationalError, query: str) -> web.Response: def sql_operational_error(error: OperationalError, query: str) -> web.Response:
return web.json_response( return web.json_response(