Pull out database validation into validator class

This commit is contained in:
Joseph Schorr 2017-01-30 16:24:58 -05:00
parent 484977f728
commit f933b3e295
5 changed files with 61 additions and 7 deletions

View file

@ -46,12 +46,6 @@ class TestValidateConfig(unittest.TestCase):
# Skip mail.
self.validated.add('mail')
def test_validate_database(self):
with self.assertRaisesRegexp(Exception, 'database not properly initialized'):
self.validate('database', {
'DB_URI': 'mysql://somehost',
})
def test_validate_jwt(self):
with self.assertRaisesRegexp(ConfigValidationException, 'Missing JWT Verification endpoint'):
self.validate('jwt', {

View file

@ -29,6 +29,7 @@ from util.secscan.api import SecurityScannerAPI
from util.registry.torrent import torrent_jwt
from util.security.signing import SIGNING_ENGINES
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
from util.config.validators.database import DatabaseValidator
logger = logging.getLogger(__name__)
@ -522,7 +523,7 @@ def _validate_bittorrent(config, user_obj, _):
VALIDATORS = {
'database': _validate_database,
DatabaseValidator.name: DatabaseValidator.validate,
'redis': _validate_redis,
'registry-storage': _validate_registry_storage,
'mail': _validate_mailing,

View file

@ -0,0 +1,20 @@
from abc import ABCMeta, abstractmethod, abstractproperty
from six import add_metaclass
class ConfigValidationException(Exception):
""" Exception raised when the configuration fails to validate for a known reason. """
pass
@add_metaclass(ABCMeta)
class BaseValidator(object):
@abstractproperty
def name(self):
""" The key for the validation API. """
pass
@classmethod
@abstractmethod
def validate(cls, config, user, user_password):
""" Raises Exception if failure to validate. """
pass

View file

@ -0,0 +1,18 @@
from peewee import OperationalError
from data.database import validate_database_url
from util.config.validators import BaseValidator, ConfigValidationException
class DatabaseValidator(BaseValidator):
name = "database"
@classmethod
def validate(cls, config, user, user_password):
""" Validates connecting to the database. """
try:
validate_database_url(config['DB_URI'], config.get('DB_CONNECTION_ARGS', {}))
except OperationalError as ex:
if ex.args and len(ex.args) > 1:
raise ConfigValidationException(ex.args[1])
else:
raise ex

View file

@ -0,0 +1,21 @@
import pytest
from util.config.validators import ConfigValidationException
from util.config.validators.database import DatabaseValidator
@pytest.mark.parametrize('unvalidated_config,user,user_password,expected', [
(None, None, None, TypeError),
({}, None, None, KeyError),
({'DB_URI': 'sqlite:///:memory:'}, None, None, None),
({'DB_URI': 'invalid:///:memory:'}, None, None, KeyError),
({'DB_NOTURI': 'sqlite:///:memory:'}, None, None, KeyError),
({'DB_URI': 'mysql:///someinvalid'}, None, None, ConfigValidationException),
])
def test_validate_database(unvalidated_config, user, user_password, expected):
validator = DatabaseValidator()
if expected is not None:
with pytest.raises(expected):
validator.validate(unvalidated_config, user, user_password)
else:
validator.validate(unvalidated_config, user, user_password)