diff --git a/data/database.py b/data/database.py index a2fb8d632..09971b0dd 100644 --- a/data/database.py +++ b/data/database.py @@ -30,6 +30,7 @@ from data.fields import (ResumableSHA256Field, ResumableSHA1Field, JSONField, Ba from data.text import match_mysql, match_like from data.read_slave import ReadSlaveModel from util.names import urn_generator +from util.validation import validate_postgres_precondition logger = logging.getLogger(__name__) @@ -70,6 +71,12 @@ SCHEME_RANDOM_FUNCTION = { } +PRECONDITION_VALIDATION = { + 'postgresql': validate_postgres_precondition, + 'postgresql+psycopg2': validate_postgres_precondition, +} + + _EXTRA_ARGS = { 'mysql': dict(charset='utf8mb4'), 'mysql+pymysql': dict(charset='utf8mb4'), @@ -284,6 +291,25 @@ def validate_database_url(url, db_kwargs, connect_timeout=5): pass +def validate_database_precondition(url, db_kwargs, connect_timeout=5): + """ Validates that we can connect to the given database URL and the database meets our + precondition. Raises an exception if the validation fails. """ + db_kwargs = db_kwargs.copy() + try: + driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False, + allow_pooling=False) + driver.connect() + pre_condition_check = PRECONDITION_VALIDATION.get(make_url(url).drivername) + if pre_condition_check: + pre_condition_check(driver) + + finally: + try: + driver.close() + except: + pass + + def _wrap_for_retry(driver): return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {}) diff --git a/util/config/validators/validate_database.py b/util/config/validators/validate_database.py index 30fcc6d0e..dfa267397 100644 --- a/util/config/validators/validate_database.py +++ b/util/config/validators/validate_database.py @@ -1,6 +1,6 @@ from peewee import OperationalError -from data.database import validate_database_url +from data.database import validate_database_precondition from util.config.validators import BaseValidator, ConfigValidationException class DatabaseValidator(BaseValidator): @@ -12,7 +12,7 @@ class DatabaseValidator(BaseValidator): config = validator_context.config try: - validate_database_url(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {})) + validate_database_precondition(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {})) except OperationalError as ex: if ex.args and len(ex.args) > 1: raise ConfigValidationException(ex.args[1]) diff --git a/util/validation.py b/util/validation.py index 235c4afa5..eabe290c5 100644 --- a/util/validation.py +++ b/util/validation.py @@ -4,6 +4,7 @@ import json import anunidecode # Don't listen to pylint's lies. This import is required. +from peewee import OperationalError INVALID_PASSWORD_MESSAGE = 'Invalid password, password must be at least ' + \ '8 characters and contain no whitespace.' @@ -89,3 +90,12 @@ def is_json(value): except (TypeError, ValueError): return False return False + + +def validate_postgres_precondition(driver): + cursor = driver.execute_sql("SELECT extname FROM pg_extension", ("public",)) + if 'pg_trgm' not in [extname for extname, in cursor.fetchall()]: + raise OperationalError(""" + "pg_trgm" extension does not exists in the database. + Please run `CREATE EXTENSION IF NOT EXISTS pg_trgm;` as superuser on this database. + """)