diff --git a/data/database.py b/data/database.py index de7262287..5c418d59c 100644 --- a/data/database.py +++ b/data/database.py @@ -263,18 +263,27 @@ ensure_under_transaction = CallableProxy() def validate_database_url(url, db_kwargs, connect_timeout=5): + """ Validates that we can connect to the given database URL, with the given kwargs. Raises + an exception if the validation fails. """ db_kwargs = db_kwargs.copy() - driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout) - driver.connect() - driver.close() + try: + driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False, + allow_pooling=False) + driver.connect() + finally: + try: + driver.close() + except: + pass def _wrap_for_retry(driver): return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {}) -def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT): +def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT, + allow_pooling=True, allow_retry=True): parsed_url = make_url(url) if parsed_url.host: @@ -295,7 +304,7 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT): drivers = _SCHEME_DRIVERS[parsed_url.drivername] driver = drivers.driver - if os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true': + if allow_pooling and os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true': driver = drivers.pooled_driver db_kwargs['stale_timeout'] = db_kwargs.get('stale_timeout', None) db_kwargs['max_connections'] = db_kwargs.get('max_connections', None) @@ -306,8 +315,10 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT): db_kwargs.pop('stale_timeout', None) db_kwargs.pop('max_connections', None) - wrapped_driver = _wrap_for_retry(driver) - return wrapped_driver(parsed_url.database, **db_kwargs) + if allow_retry: + driver = _wrap_for_retry(driver) + + return driver(parsed_url.database, **db_kwargs) def configure(config_object): diff --git a/data/model/notification.py b/data/model/notification.py index df1f743f0..6fc47967a 100644 --- a/data/model/notification.py +++ b/data/model/notification.py @@ -1,5 +1,7 @@ import json +from peewee import SQL + from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole, RepositoryNotification, ExternalNotificationEvent, Repository, ExternalNotificationMethod, Namespace, db_for_update) @@ -82,7 +84,6 @@ def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=F elif limit: query = query.limit(limit) - from peewee import SQL return query.order_by(SQL('cd desc'))