diff --git a/util/config/validator.py b/util/config/validator.py index 8d7c36374..c2ad42da8 100644 --- a/util/config/validator.py +++ b/util/config/validator.py @@ -12,7 +12,6 @@ from data.users import LDAP_CERT_FILENAME from oauth.services.github import GithubOAuthService from oauth.services.google import GoogleOAuthService from oauth.services.gitlab import GitLabOAuthService -from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException from util.config.validators.validate_database import DatabaseValidator from util.config.validators.validate_redis import RedisValidator @@ -24,6 +23,7 @@ from util.config.validators.validate_jwt import JWTAuthValidator from util.config.validators.validate_secscan import SecurityScannerValidator from util.config.validators.validate_signer import SignerValidator from util.config.validators.validate_torrent import BittorrentValidator +from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES logger = logging.getLogger(__name__) @@ -33,7 +33,6 @@ class ConfigValidationException(Exception): # Note: Only add files required for HTTPS to the SSL_FILESNAMES list. -SSL_FILENAMES = ['ssl.cert', 'ssl.key'] DB_SSL_FILENAMES = ['database.pem'] JWT_FILENAMES = ['jwt-authn.cert'] ACI_CERT_FILENAMES = ['signing-public.gpg', 'signing-private.gpg'] @@ -180,57 +179,6 @@ def _validate_google_login(config, user_obj, _): raise ConfigValidationException('Invalid client id or client secret') -def _validate_ssl(config, user_obj, _): - """ Validates the SSL configuration (if enabled). """ - - # Skip if non-SSL. - if config.get('PREFERRED_URL_SCHEME', 'http') != 'https': - return - - # Skip if externally terminated. - if config.get('EXTERNAL_TLS_TERMINATION', False) is True: - return - - # Verify that we have all the required SSL files. - for filename in SSL_FILENAMES: - if not config_provider.volume_file_exists(filename): - raise ConfigValidationException('Missing required SSL file: %s' % filename) - - # Read the contents of the SSL certificate. - with config_provider.get_volume_file(SSL_FILENAMES[0]) as f: - cert_contents = f.read() - - # Validate the certificate. - try: - certificate = load_certificate(cert_contents) - except CertInvalidException as cie: - raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message) - - # Verify the certificate has not expired. - if certificate.expired: - raise ConfigValidationException('The specified SSL certificate has expired.') - - # Verify the hostname matches the name in the certificate. - if not certificate.matches_name(config['SERVER_HOSTNAME']): - msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % - (', '.join(list(certificate.names)), config['SERVER_HOSTNAME'])) - raise ConfigValidationException(msg) - - # Verify the private key against the certificate. - private_key_path = None - with config_provider.get_volume_file(SSL_FILENAMES[1]) as f: - private_key_path = f.name - - if not private_key_path: - # Only in testing. - return - - try: - certificate.validate_private_key(private_key_path) - except KeyInvalidException as kie: - raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message) - - VALIDATORS = { DatabaseValidator.name: DatabaseValidator.validate, RedisValidator.name: RedisValidator.validate, @@ -241,7 +189,7 @@ VALIDATORS = { 'gitlab-trigger': _validate_gitlab, 'bitbucket-trigger': _validate_bitbucket, 'google-login': _validate_google_login, - 'ssl': _validate_ssl, + SSLValidator.name: SSLValidator.validate, LDAPValidator.name: LDAPValidator.validate, JWTAuthValidator.name: JWTAuthValidator.validate, KeystoneValidator.name: KeystoneValidator.validate, diff --git a/util/config/validators/test/test_validate_ssl.py b/util/config/validators/test/test_validate_ssl.py new file mode 100644 index 000000000..ee5a4aa22 --- /dev/null +++ b/util/config/validators/test/test_validate_ssl.py @@ -0,0 +1,62 @@ +import pytest + +from mock import patch +from tempfile import NamedTemporaryFile + +from util.config.validators import ConfigValidationException +from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES +from test.test_ssl_util import generate_test_cert + +@pytest.mark.parametrize('unvalidated_config', [ + ({}), + ({'PREFERRED_URL_SCHEME': 'http'}), + ({'PREFERRED_URL_SCHEME': 'https', 'EXTERNAL_TLS_TERMINATION': True}), +]) +def test_skip_validate_ssl(unvalidated_config): + validator = SSLValidator() + validator.validate(unvalidated_config, None, None) + + +@pytest.mark.parametrize('cert, expected_error, error_message', [ + ('invalidcert', ConfigValidationException, 'Could not load SSL certificate: no start line'), + (generate_test_cert(hostname='someserver'), None, None), + (generate_test_cert(hostname='invalidserver'), ConfigValidationException, + 'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'), +]) +def test_validate_ssl(cert, expected_error, error_message): + with NamedTemporaryFile(delete=False) as cert_file: + cert_file.write(cert[0]) + cert_file.seek(0) + + with NamedTemporaryFile(delete=False) as key_file: + key_file.write(cert[1]) + key_file.seek(0) + + def return_true(filename): + return True + + def get_volume_file(filename): + if filename == SSL_FILENAMES[0]: + return open(cert_file.name) + + if filename == SSL_FILENAMES[1]: + return open(key_file.name) + + return None + + config = { + 'PREFERRED_URL_SCHEME': 'https', + 'SERVER_HOSTNAME': 'someserver', + } + + with patch('app.config_provider.volume_file_exists', return_true): + with patch('app.config_provider.get_volume_file', get_volume_file): + validator = SSLValidator() + + if expected_error is not None: + with pytest.raises(expected_error) as ipe: + validator.validate(config, None, None) + + assert ipe.value.message == error_message + else: + validator.validate(config, None, None) diff --git a/util/config/validators/validate_ssl.py b/util/config/validators/validate_ssl.py new file mode 100644 index 000000000..ea1ae3188 --- /dev/null +++ b/util/config/validators/validate_ssl.py @@ -0,0 +1,59 @@ +from app import config_provider +from util.config.validators import BaseValidator, ConfigValidationException +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException + +SSL_FILENAMES = ['ssl.cert', 'ssl.key'] + +class SSLValidator(BaseValidator): + name = "ssl" + + @classmethod + def validate(cls, config, user, user_password): + """ Validates the SSL configuration (if enabled). """ + + # Skip if non-SSL. + if config.get('PREFERRED_URL_SCHEME', 'http') != 'https': + return + + # Skip if externally terminated. + if config.get('EXTERNAL_TLS_TERMINATION', False) is True: + return + + # Verify that we have all the required SSL files. + for filename in SSL_FILENAMES: + if not config_provider.volume_file_exists(filename): + raise ConfigValidationException('Missing required SSL file: %s' % filename) + + # Read the contents of the SSL certificate. + with config_provider.get_volume_file(SSL_FILENAMES[0]) as f: + cert_contents = f.read() + + # Validate the certificate. + try: + certificate = load_certificate(cert_contents) + except CertInvalidException as cie: + raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message) + + # Verify the certificate has not expired. + if certificate.expired: + raise ConfigValidationException('The specified SSL certificate has expired.') + + # Verify the hostname matches the name in the certificate. + if not certificate.matches_name(config['SERVER_HOSTNAME']): + msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % + (', '.join(list(certificate.names)), config['SERVER_HOSTNAME'])) + raise ConfigValidationException(msg) + + # Verify the private key against the certificate. + private_key_path = None + with config_provider.get_volume_file(SSL_FILENAMES[1]) as f: + private_key_path = f.name + + if not private_key_path: + # Only in testing. + return + + try: + certificate.validate_private_key(private_key_path) + except KeyInvalidException as kie: + raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)