diff --git a/config_app/config_endpoints/api/suconfig.py b/config_app/config_endpoints/api/suconfig.py index 7c9f18c09..62615af52 100644 --- a/config_app/config_endpoints/api/suconfig.py +++ b/config_app/config_endpoints/api/suconfig.py @@ -102,7 +102,7 @@ class SuperUserRegistryStatus(ApiResource): } config = config_provider.get_config() - if config and config['SETUP_COMPLETE']: + if config and config.get('SETUP_COMPLETE'): return { 'status': 'config' } diff --git a/config_app/config_test/test_suconfig_api.py b/config_app/config_test/test_suconfig_api.py index af85fb4d8..408b96a8b 100644 --- a/config_app/config_test/test_suconfig_api.py +++ b/config_app/config_test/test_suconfig_api.py @@ -1,4 +1,5 @@ import unittest +import mock from data.database import User from data import model @@ -28,11 +29,40 @@ class FreshConfigProvider(object): class TestSuperUserRegistryStatus(ApiTestCase): - def test_registry_status(self): + def test_registry_status_no_config(self): with FreshConfigProvider(): json = self.getJsonResponse(SuperUserRegistryStatus) self.assertEquals('config-db', json['status']) + @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=False)) + def test_registry_status_no_database(self): + with FreshConfigProvider(): + config_provider.save_config({'key': 'value'}) + json = self.getJsonResponse(SuperUserRegistryStatus) + self.assertEquals('setup-db', json['status']) + + @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) + def test_registry_status_db_has_superuser(self): + with FreshConfigProvider(): + config_provider.save_config({'key': 'value'}) + json = self.getJsonResponse(SuperUserRegistryStatus) + self.assertEquals('config', json['status']) + + @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) + @mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=False)) + def test_registry_status_db_no_superuser(self): + with FreshConfigProvider(): + config_provider.save_config({'key': 'value'}) + json = self.getJsonResponse(SuperUserRegistryStatus) + self.assertEquals('create-superuser', json['status']) + + @mock.patch("config_app.config_endpoints.api.suconfig.database_is_valid", mock.Mock(return_value=True)) + @mock.patch("config_app.config_endpoints.api.suconfig.database_has_users", mock.Mock(return_value=True)) + def test_registry_status_setup_complete(self): + with FreshConfigProvider(): + config_provider.save_config({'key': 'value', 'SETUP_COMPLETE': True}) + json = self.getJsonResponse(SuperUserRegistryStatus) + self.assertEquals('config', json['status']) class TestSuperUserConfigFile(ApiTestCase): def test_get_superuser_invalid_filename(self): diff --git a/data/database.py b/data/database.py index a2fb8d632..09971b0dd 100644 --- a/data/database.py +++ b/data/database.py @@ -30,6 +30,7 @@ from data.fields import (ResumableSHA256Field, ResumableSHA1Field, JSONField, Ba from data.text import match_mysql, match_like from data.read_slave import ReadSlaveModel from util.names import urn_generator +from util.validation import validate_postgres_precondition logger = logging.getLogger(__name__) @@ -70,6 +71,12 @@ SCHEME_RANDOM_FUNCTION = { } +PRECONDITION_VALIDATION = { + 'postgresql': validate_postgres_precondition, + 'postgresql+psycopg2': validate_postgres_precondition, +} + + _EXTRA_ARGS = { 'mysql': dict(charset='utf8mb4'), 'mysql+pymysql': dict(charset='utf8mb4'), @@ -284,6 +291,25 @@ def validate_database_url(url, db_kwargs, connect_timeout=5): pass +def validate_database_precondition(url, db_kwargs, connect_timeout=5): + """ Validates that we can connect to the given database URL and the database meets our + precondition. Raises an exception if the validation fails. """ + db_kwargs = db_kwargs.copy() + try: + driver = _db_from_url(url, db_kwargs, connect_timeout=connect_timeout, allow_retry=False, + allow_pooling=False) + driver.connect() + pre_condition_check = PRECONDITION_VALIDATION.get(make_url(url).drivername) + if pre_condition_check: + pre_condition_check(driver) + + finally: + try: + driver.close() + except: + pass + + def _wrap_for_retry(driver): return type('Retrying' + driver.__class__.__name__, (RetryOperationalError, driver), {}) diff --git a/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py index 281e148c8..36304ccfe 100644 --- a/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py +++ b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py @@ -15,9 +15,6 @@ import sqlalchemy as sa from sqlalchemy.dialects import mysql def upgrade(tables, tester): - if op.get_bind().engine.name == 'postgresql': - op.execute('CREATE EXTENSION IF NOT EXISTS pg_trgm') - # ### commands auto generated by Alembic - please adjust! ### op.create_index('repository_description__fulltext', 'repository', ['description'], unique=False, postgresql_using='gin', postgresql_ops={'description': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT') op.create_index('repository_name__fulltext', 'repository', ['name'], unique=False, postgresql_using='gin', postgresql_ops={'name': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT') diff --git a/scripts/ci b/scripts/ci index 1d21607dc..bbbbe7322 100755 --- a/scripts/ci +++ b/scripts/ci @@ -12,6 +12,7 @@ IMAGE_TAR="${CACHE_DIR}/${IMAGE}-${IMAGE_TAG}.tar.gz" MYSQL_IMAGE="mysql:5.7" POSTGRES_IMAGE="postgres:9.6" +POSTGRES_CONTAINER="test_postgres" export MYSQL_ROOT_PASSWORD="quay" export MYSQL_USER="quay" @@ -110,13 +111,18 @@ postgres_ping() { postgres_start() { - docker run --net=host -d -e POSTGRES_USER -e POSTGRES_PASSWORD \ + docker run --name="${POSTGRES_CONTAINER}" --net=host -d -e POSTGRES_USER -e POSTGRES_PASSWORD \ -e POSTGRES_DB "${POSTGRES_IMAGE}" if ! (sleep 10 && postgres_ping); then echo "PostgreSQL failed to respond in time." exit 1 - fi + fi +} + + +postgres_init() { + docker exec "${POSTGRES_CONTAINER}" psql -U "${POSTGRES_USER}" -d "${POSTGRES_DB}" -c 'CREATE EXTENSION IF NOT EXISTS pg_trgm;' } @@ -129,6 +135,7 @@ postgres() { load_image postgres_start + postgres_init quay_run make full-db-test } diff --git a/util/config/validators/validate_database.py b/util/config/validators/validate_database.py index 30fcc6d0e..dfa267397 100644 --- a/util/config/validators/validate_database.py +++ b/util/config/validators/validate_database.py @@ -1,6 +1,6 @@ from peewee import OperationalError -from data.database import validate_database_url +from data.database import validate_database_precondition from util.config.validators import BaseValidator, ConfigValidationException class DatabaseValidator(BaseValidator): @@ -12,7 +12,7 @@ class DatabaseValidator(BaseValidator): config = validator_context.config try: - validate_database_url(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {})) + validate_database_precondition(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {})) except OperationalError as ex: if ex.args and len(ex.args) > 1: raise ConfigValidationException(ex.args[1]) diff --git a/util/validation.py b/util/validation.py index 235c4afa5..eabe290c5 100644 --- a/util/validation.py +++ b/util/validation.py @@ -4,6 +4,7 @@ import json import anunidecode # Don't listen to pylint's lies. This import is required. +from peewee import OperationalError INVALID_PASSWORD_MESSAGE = 'Invalid password, password must be at least ' + \ '8 characters and contain no whitespace.' @@ -89,3 +90,12 @@ def is_json(value): except (TypeError, ValueError): return False return False + + +def validate_postgres_precondition(driver): + cursor = driver.execute_sql("SELECT extname FROM pg_extension", ("public",)) + if 'pg_trgm' not in [extname for extname, in cursor.fetchall()]: + raise OperationalError(""" + "pg_trgm" extension does not exists in the database. + Please run `CREATE EXTENSION IF NOT EXISTS pg_trgm;` as superuser on this database. + """)