Enhancements for validation of DB urls

We now turn off retries and DB pooling, and make sure to always close the connection
This commit is contained in:
Joseph Schorr 2018-07-17 12:49:53 -04:00
parent d15dcae505
commit 9a40e99a8f
2 changed files with 20 additions and 8 deletions

View file

@ -263,18 +263,27 @@ ensure_under_transaction = CallableProxy()
def validate_database_url(url, db_kwargs, connect_timeout=5):
""" Validates that we can connect to the given database URL, with the given kwargs. Raises
an exception if the validation fails. """
db_kwargs = db_kwargs.copy()
driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout)
try:
driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False,
allow_pooling=False)
driver.connect()
finally:
try:
driver.close()
except:
pass
def _wrap_for_retry(driver):
return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {})
def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT,
allow_pooling=True, allow_retry=True):
parsed_url = make_url(url)
if parsed_url.host:
@ -295,7 +304,7 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
drivers = _SCHEME_DRIVERS[parsed_url.drivername]
driver = drivers.driver
if os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true':
if allow_pooling and os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true':
driver = drivers.pooled_driver
db_kwargs['stale_timeout'] = db_kwargs.get('stale_timeout', None)
db_kwargs['max_connections'] = db_kwargs.get('max_connections', None)
@ -306,8 +315,10 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
db_kwargs.pop('stale_timeout', None)
db_kwargs.pop('max_connections', None)
wrapped_driver = _wrap_for_retry(driver)
return wrapped_driver(parsed_url.database, **db_kwargs)
if allow_retry:
driver = _wrap_for_retry(driver)
return driver(parsed_url.database, **db_kwargs)
def configure(config_object):

View file

@ -1,5 +1,7 @@
import json
from peewee import SQL
from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole,
RepositoryNotification, ExternalNotificationEvent, Repository,
ExternalNotificationMethod, Namespace, db_for_update)
@ -82,7 +84,6 @@ def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=F
elif limit:
query = query.limit(limit)
from peewee import SQL
return query.order_by(SQL('cd desc'))