Pull out ssl validation into validator class

This commit is contained in:
Joseph Schorr 2017-02-15 15:17:07 -05:00
parent e76b95f0e6
commit 620e377faf
3 changed files with 123 additions and 54 deletions

View file

@ -0,0 +1,62 @@
import pytest
from mock import patch
from tempfile import NamedTemporaryFile
from util.config.validators import ConfigValidationException
from util.config.validators.validate_ssl import SSLValidator, SSL_FILENAMES
from test.test_ssl_util import generate_test_cert
@pytest.mark.parametrize('unvalidated_config', [
({}),
({'PREFERRED_URL_SCHEME': 'http'}),
({'PREFERRED_URL_SCHEME': 'https', 'EXTERNAL_TLS_TERMINATION': True}),
])
def test_skip_validate_ssl(unvalidated_config):
validator = SSLValidator()
validator.validate(unvalidated_config, None, None)
@pytest.mark.parametrize('cert, expected_error, error_message', [
('invalidcert', ConfigValidationException, 'Could not load SSL certificate: no start line'),
(generate_test_cert(hostname='someserver'), None, None),
(generate_test_cert(hostname='invalidserver'), ConfigValidationException,
'Supported names "invalidserver" in SSL cert do not match server hostname "someserver"'),
])
def test_validate_ssl(cert, expected_error, error_message):
with NamedTemporaryFile(delete=False) as cert_file:
cert_file.write(cert[0])
cert_file.seek(0)
with NamedTemporaryFile(delete=False) as key_file:
key_file.write(cert[1])
key_file.seek(0)
def return_true(filename):
return True
def get_volume_file(filename):
if filename == SSL_FILENAMES[0]:
return open(cert_file.name)
if filename == SSL_FILENAMES[1]:
return open(key_file.name)
return None
config = {
'PREFERRED_URL_SCHEME': 'https',
'SERVER_HOSTNAME': 'someserver',
}
with patch('app.config_provider.volume_file_exists', return_true):
with patch('app.config_provider.get_volume_file', get_volume_file):
validator = SSLValidator()
if expected_error is not None:
with pytest.raises(expected_error) as ipe:
validator.validate(config, None, None)
assert ipe.value.message == error_message
else:
validator.validate(config, None, None)

View file

@ -0,0 +1,59 @@
from app import config_provider
from util.config.validators import BaseValidator, ConfigValidationException
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
SSL_FILENAMES = ['ssl.cert', 'ssl.key']
class SSLValidator(BaseValidator):
name = "ssl"
@classmethod
def validate(cls, config, user, user_password):
""" Validates the SSL configuration (if enabled). """
# Skip if non-SSL.
if config.get('PREFERRED_URL_SCHEME', 'http') != 'https':
return
# Skip if externally terminated.
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
return
# Verify that we have all the required SSL files.
for filename in SSL_FILENAMES:
if not config_provider.volume_file_exists(filename):
raise ConfigValidationException('Missing required SSL file: %s' % filename)
# Read the contents of the SSL certificate.
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
cert_contents = f.read()
# Validate the certificate.
try:
certificate = load_certificate(cert_contents)
except CertInvalidException as cie:
raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message)
# Verify the certificate has not expired.
if certificate.expired:
raise ConfigValidationException('The specified SSL certificate has expired.')
# Verify the hostname matches the name in the certificate.
if not certificate.matches_name(config['SERVER_HOSTNAME']):
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(certificate.names)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
# Verify the private key against the certificate.
private_key_path = None
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
private_key_path = f.name
if not private_key_path:
# Only in testing.
return
try:
certificate.validate_private_key(private_key_path)
except KeyInvalidException as kie:
raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)