Add precondition check for database validation

This commit is contained in:
Sida Chen 2019-03-07 14:15:38 -05:00
parent 3ecc6574ae
commit 7985167411
3 changed files with 38 additions and 2 deletions

View file

@ -30,6 +30,7 @@ from data.fields import (ResumableSHA256Field, ResumableSHA1Field, JSONField, Ba
from data.text import match_mysql, match_like
from data.read_slave import ReadSlaveModel
from util.names import urn_generator
from util.validation import validate_postgres_precondition
logger = logging.getLogger(__name__)
@ -70,6 +71,12 @@ SCHEME_RANDOM_FUNCTION = {
}
PRECONDITION_VALIDATION = {
'postgresql': validate_postgres_precondition,
'postgresql+psycopg2': validate_postgres_precondition,
}
_EXTRA_ARGS = {
'mysql': dict(charset='utf8mb4'),
'mysql+pymysql': dict(charset='utf8mb4'),
@ -284,6 +291,25 @@ def validate_database_url(url, db_kwargs, connect_timeout=5):
pass
def validate_database_precondition(url, db_kwargs, connect_timeout=5):
""" Validates that we can connect to the given database URL and the database meets our
precondition. Raises an exception if the validation fails. """
db_kwargs = db_kwargs.copy()
try:
driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False,
allow_pooling=False)
driver.connect()
pre_condition_check = PRECONDITION_VALIDATION.get(make_url(url).drivername)
if pre_condition_check:
pre_condition_check(driver)
finally:
try:
driver.close()
except:
pass
def _wrap_for_retry(driver):
return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {})

View file

@ -1,6 +1,6 @@
from peewee import OperationalError
from data.database import validate_database_url
from data.database import validate_database_precondition
from util.config.validators import BaseValidator, ConfigValidationException
class DatabaseValidator(BaseValidator):
@ -12,7 +12,7 @@ class DatabaseValidator(BaseValidator):
config = validator_context.config
try:
validate_database_url(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {}))
validate_database_precondition(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {}))
except OperationalError as ex:
if ex.args and len(ex.args) > 1:
raise ConfigValidationException(ex.args[1])

View file

@ -4,6 +4,7 @@ import json
import anunidecode # Don't listen to pylint's lies. This import is required.
from peewee import OperationalError
INVALID_PASSWORD_MESSAGE = 'Invalid password, password must be at least ' + \
'8 characters and contain no whitespace.'
@ -89,3 +90,12 @@ def is_json(value):
except (TypeError, ValueError):
return False
return False
def validate_postgres_precondition(driver):
cursor = driver.execute_sql("SELECT extname FROM pg_extension", ("public",))
if 'pg_trgm' not in [extname for extname, in cursor.fetchall()]:
raise OperationalError("""
"pg_trgm" extension does not exists in the database.
Please run `CREATE EXTENSION IF NOT EXISTS pg_trgm;` as superuser on this database.
""")