From f933b3e295a72a9d8746ed907066204ae22545a4 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Mon, 30 Jan 2017 16:24:58 -0500 Subject: [PATCH] Pull out database validation into validator class --- test/test_validate_config.py | 6 ------ util/config/validator.py | 3 ++- util/config/validators/__init__.py | 20 +++++++++++++++++++ util/config/validators/database.py | 18 +++++++++++++++++ util/config/validators/test/test_database.py | 21 ++++++++++++++++++++ 5 files changed, 61 insertions(+), 7 deletions(-) create mode 100644 util/config/validators/__init__.py create mode 100644 util/config/validators/database.py create mode 100644 util/config/validators/test/test_database.py diff --git a/test/test_validate_config.py b/test/test_validate_config.py index 2ad5cf2a8..20d8b3d28 100644 --- a/test/test_validate_config.py +++ b/test/test_validate_config.py @@ -46,12 +46,6 @@ class TestValidateConfig(unittest.TestCase): # Skip mail. self.validated.add('mail') - def test_validate_database(self): - with self.assertRaisesRegexp(Exception, 'database not properly initialized'): - self.validate('database', { - 'DB_URI': 'mysql://somehost', - }) - def test_validate_jwt(self): with self.assertRaisesRegexp(ConfigValidationException, 'Missing JWT Verification endpoint'): self.validate('jwt', { diff --git a/util/config/validator.py b/util/config/validator.py index 40b4393e2..e37575400 100644 --- a/util/config/validator.py +++ b/util/config/validator.py @@ -29,6 +29,7 @@ from util.secscan.api import SecurityScannerAPI from util.registry.torrent import torrent_jwt from util.security.signing import SIGNING_ENGINES from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException +from util.config.validators.database import DatabaseValidator logger = logging.getLogger(__name__) @@ -522,7 +523,7 @@ def _validate_bittorrent(config, user_obj, _): VALIDATORS = { - 'database': _validate_database, + DatabaseValidator.name: DatabaseValidator.validate, 'redis': _validate_redis, 'registry-storage': _validate_registry_storage, 'mail': _validate_mailing, diff --git a/util/config/validators/__init__.py b/util/config/validators/__init__.py new file mode 100644 index 000000000..a3edeeb12 --- /dev/null +++ b/util/config/validators/__init__.py @@ -0,0 +1,20 @@ +from abc import ABCMeta, abstractmethod, abstractproperty +from six import add_metaclass + +class ConfigValidationException(Exception): + """ Exception raised when the configuration fails to validate for a known reason. """ + pass + + +@add_metaclass(ABCMeta) +class BaseValidator(object): + @abstractproperty + def name(self): + """ The key for the validation API. """ + pass + + @classmethod + @abstractmethod + def validate(cls, config, user, user_password): + """ Raises Exception if failure to validate. """ + pass diff --git a/util/config/validators/database.py b/util/config/validators/database.py new file mode 100644 index 000000000..5fb27fa80 --- /dev/null +++ b/util/config/validators/database.py @@ -0,0 +1,18 @@ +from peewee import OperationalError + +from data.database import validate_database_url +from util.config.validators import BaseValidator, ConfigValidationException + +class DatabaseValidator(BaseValidator): + name = "database" + + @classmethod + def validate(cls, config, user, user_password): + """ Validates connecting to the database. """ + try: + validate_database_url(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {})) + except OperationalError as ex: + if ex.args and len(ex.args) > 1: + raise ConfigValidationException(ex.args[1]) + else: + raise ex diff --git a/util/config/validators/test/test_database.py b/util/config/validators/test/test_database.py new file mode 100644 index 000000000..74612641e --- /dev/null +++ b/util/config/validators/test/test_database.py @@ -0,0 +1,21 @@ +import pytest + +from util.config.validators import ConfigValidationException +from util.config.validators.database import DatabaseValidator + +@pytest.mark.parametrize('unvalidated_config,user,user_password,expected', [ + (None, None, None, TypeError), + ({}, None, None, KeyError), + ({'DB_URI': 'sqlite:///:memory:'}, None, None, None), + ({'DB_URI': 'invalid:///:memory:'}, None, None, KeyError), + ({'DB_NOTURI': 'sqlite:///:memory:'}, None, None, KeyError), + ({'DB_URI': 'mysql:///someinvalid'}, None, None, ConfigValidationException), +]) +def test_validate_database(unvalidated_config, user, user_password, expected): + validator = DatabaseValidator() + + if expected is not None: + with pytest.raises(expected): + validator.validate(unvalidated_config, user, user_password) + else: + validator.validate(unvalidated_config, user, user_password)