Enhancements for validation of DB urls

We now turn off retries and DB pooling, and make sure to always close the connection
This commit is contained in:
Joseph Schorr 2018-07-17 12:49:53 -04:00
parent d15dcae505
commit 9a40e99a8f
2 changed files with 20 additions and 8 deletions

View file

@ -263,18 +263,27 @@ ensure_under_transaction = CallableProxy()
def validate_database_url(url, db_kwargs, connect_timeout=5): def validate_database_url(url, db_kwargs, connect_timeout=5):
""" Validates that we can connect to the given database URL, with the given kwargs. Raises
an exception if the validation fails. """
db_kwargs = db_kwargs.copy() db_kwargs = db_kwargs.copy()
driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout) try:
driver.connect() driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False,
driver.close() allow_pooling=False)
driver.connect()
finally:
try:
driver.close()
except:
pass
def _wrap_for_retry(driver): def _wrap_for_retry(driver):
return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {}) return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {})
def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT): def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT,
allow_pooling=True, allow_retry=True):
parsed_url = make_url(url) parsed_url = make_url(url)
if parsed_url.host: if parsed_url.host:
@ -295,7 +304,7 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
drivers = _SCHEME_DRIVERS[parsed_url.drivername] drivers = _SCHEME_DRIVERS[parsed_url.drivername]
driver = drivers.driver driver = drivers.driver
if os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true': if allow_pooling and os.getenv('DB_CONNECTION_POOLING', 'false').lower() == 'true':
driver = drivers.pooled_driver driver = drivers.pooled_driver
db_kwargs['stale_timeout'] = db_kwargs.get('stale_timeout', None) db_kwargs['stale_timeout'] = db_kwargs.get('stale_timeout', None)
db_kwargs['max_connections'] = db_kwargs.get('max_connections', None) db_kwargs['max_connections'] = db_kwargs.get('max_connections', None)
@ -306,8 +315,10 @@ def _db_from_url(url, db_kwargs, connect_timeout=DEFAULT_DB_CONNECT_TIMEOUT):
db_kwargs.pop('stale_timeout', None) db_kwargs.pop('stale_timeout', None)
db_kwargs.pop('max_connections', None) db_kwargs.pop('max_connections', None)
wrapped_driver = _wrap_for_retry(driver) if allow_retry:
return wrapped_driver(parsed_url.database, **db_kwargs) driver = _wrap_for_retry(driver)
return driver(parsed_url.database, **db_kwargs)
def configure(config_object): def configure(config_object):

View file

@ -1,5 +1,7 @@
import json import json
from peewee import SQL
from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole, from data.database import (Notification, NotificationKind, User, Team, TeamMember, TeamRole,
RepositoryNotification, ExternalNotificationEvent, Repository, RepositoryNotification, ExternalNotificationEvent, Repository,
ExternalNotificationMethod, Namespace, db_for_update) ExternalNotificationMethod, Namespace, db_for_update)
@ -82,7 +84,6 @@ def list_notifications(user, kind_name=None, id_filter=None, include_dismissed=F
elif limit: elif limit:
query = query.limit(limit) query = query.limit(limit)
from peewee import SQL
return query.order_by(SQL('cd desc')) return query.order_by(SQL('cd desc'))