From 3a24871422c6bb724acc72d6776c1c4483d93a96 Mon Sep 17 00:00:00 2001 From: Joseph Schorr Date: Fri, 9 Dec 2016 18:31:02 -0500 Subject: [PATCH] Add SSL certificate utility and tests --- test/test_ssl_util.py | 105 ++++++++++++++++++++++++++++ test/test_validate_config.py | 2 - util/config/provider/k8sprovider.py | 2 +- util/config/validator.py | 61 +++++----------- util/security/ssl.py | 81 +++++++++++++++++++++ 5 files changed, 205 insertions(+), 46 deletions(-) create mode 100644 test/test_ssl_util.py create mode 100644 util/security/ssl.py diff --git a/test/test_ssl_util.py b/test/test_ssl_util.py new file mode 100644 index 000000000..69bc0a437 --- /dev/null +++ b/test/test_ssl_util.py @@ -0,0 +1,105 @@ +import unittest + +from tempfile import NamedTemporaryFile +from OpenSSL import crypto + +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException + +class TestSSLCertificate(unittest.TestCase): + def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000): + # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/ + # Create a key pair. + k = crypto.PKey() + k.generate_key(crypto.TYPE_RSA, 1024) + + # Create a self-signed cert. + cert = crypto.X509() + cert.get_subject().CN = hostname + + # Add the subjectAltNames (if necessary). + if san_list is not None: + cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))]) + + cert.set_serial_number(1000) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(expires) + cert.set_issuer(cert.get_subject()) + + cert.set_pubkey(k) + cert.sign(k, 'sha1') + + # Dump the certificate and private key in PEM format. + cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) + key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) + + return (cert_data, key_data) + + def test_load_certificate(self): + # Try loading an invalid certificate. + with self.assertRaisesRegexp(CertInvalidException, 'no start line'): + load_certificate('someinvalidcontents') + + # Load a valid certificate. + (public_key_data, _) = self._generate_cert() + + cert = load_certificate(public_key_data) + self.assertFalse(cert.expired) + self.assertEquals(set(['somehostname']), cert.names) + self.assertTrue(cert.matches_name('somehostname')) + + def test_expired_certificate(self): + (public_key_data, _) = self._generate_cert(expires=-100) + + cert = load_certificate(public_key_data) + self.assertTrue(cert.expired) + + def test_hostnames(self): + (public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz']) + cert = load_certificate(public_key_data) + self.assertEquals(set(['foo', 'bar', 'baz']), cert.names) + + for name in cert.names: + self.assertTrue(cert.matches_name(name)) + + def test_nondns_hostnames(self): + (public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg']) + cert = load_certificate(public_key_data) + self.assertEquals(set(['foo']), cert.names) + + def test_validate_private_key(self): + (public_key_data, private_key_data) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + cert.validate_private_key(private_key.name) + + def test_invalid_private_key(self): + (public_key_data, _) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write('somerandomdata') + private_key.seek(0) + + cert = load_certificate(public_key_data) + with self.assertRaisesRegexp(KeyInvalidException, 'no start line'): + cert.validate_private_key(private_key.name) + + def test_mismatch_private_key(self): + (public_key_data, _) = self._generate_cert() + (_, private_key_data) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'): + cert.validate_private_key(private_key.name) + + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_validate_config.py b/test/test_validate_config.py index 27cca658d..2ad5cf2a8 100644 --- a/test/test_validate_config.py +++ b/test/test_validate_config.py @@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase): 'PREFERRED_URL_SCHEME': 'https', }) - # TODO(jschorr): Add SSL verification tests once file lookup is fixed. - def test_validate_keystone(self): with self.assertRaisesRegexp(ConfigValidationException, 'Verification of superuser someuser failed'): diff --git a/util/config/provider/k8sprovider.py b/util/config/provider/k8sprovider.py index 7552eb839..0feb6ca60 100644 --- a/util/config/provider/k8sprovider.py +++ b/util/config/provider/k8sprovider.py @@ -79,7 +79,7 @@ class KubernetesConfigProvider(FileConfigProvider): # as an error, as it seems to be a common issue. namespace_url = 'namespaces/%s' % (QE_NAMESPACE) response = self._execute_k8s_api('GET', namespace_url) - if response.status_code / 100 != 2: + if response.status_code // 100 != 2: msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE raise CannotWriteConfigException(msg) diff --git a/util/config/validator.py b/util/config/validator.py index c3710c7d0..60e44b443 100644 --- a/util/config/validator.py +++ b/util/config/validator.py @@ -3,10 +3,8 @@ import subprocess import time from StringIO import StringIO -from fnmatch import fnmatch from hashlib import sha1 -import OpenSSL import ldap import peewee import redis @@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC from util.secscan.api import SecurityScannerAPI from util.registry.torrent import torrent_jwt from util.security.signing import SIGNING_ENGINES +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException logger = logging.getLogger(__name__) @@ -257,22 +256,32 @@ def _validate_ssl(config, user_obj, _): if config.get('EXTERNAL_TLS_TERMINATION', False) is True: return + # Verify that we have all the required SSL files. for filename in SSL_FILENAMES: if not config_provider.volume_file_exists(filename): raise ConfigValidationException('Missing required SSL file: %s' % filename) + # Read the contents of the SSL certificate. with config_provider.get_volume_file(SSL_FILENAMES[0]) as f: cert_contents = f.read() # Validate the certificate. try: - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents) - except: - raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?') + certificate = load_certificate(cert_contents) + except CertInvalidException as cie: + raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message) - if cert.has_expired(): + # Verify the certificate has not expired. + if certificate.expired: raise ConfigValidationException('The specified SSL certificate has expired.') + # Verify the hostname matches the name in the certificate. + if not certificate.matches_name(config['SERVER_HOSTNAME']): + msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % + (', '.join(list(certificate.names)), config['SERVER_HOSTNAME'])) + raise ConfigValidationException(msg) + + # Verify the private key against the certificate. private_key_path = None with config_provider.get_volume_file(SSL_FILENAMES[1]) as f: private_key_path = f.name @@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _): # Only in testing. return - # Validate the private key with the certificate. - context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) - context.use_certificate(cert) - try: - context.use_privatekey_file(private_key_path) - except: - raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?') - - try: - context.check_privatekey() - except OpenSSL.SSL.Error as e: - raise ConfigValidationException('SSL key failed to validate: %s' % str(e)) - - # Verify the hostname matches the name in the certificate. - common_name = cert.get_subject().commonName - if common_name is None: - raise ConfigValidationException('Missing CommonName (CN) from SSL certificate') - - # Build the list of allowed host patterns. - hosts = set([common_name]) - - # Find the DNS extension, if any. - for i in range(0, cert.get_extension_count()): - ext = cert.get_extension(i) - if ext.get_short_name() == 'subjectAltName': - value = str(ext) - hosts.update([host.strip()[4:] for host in value.split(',')]) - - # Check each host. - for host in hosts: - if fnmatch(config['SERVER_HOSTNAME'], host): - return - - msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % - (', '.join(list(hosts)), config['SERVER_HOSTNAME'])) - raise ConfigValidationException(msg) - + certificate.validate_private_key(private_key_path) + except KeyInvalidException as kie: + raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message) def _validate_ldap(config, user_obj, password): diff --git a/util/security/ssl.py b/util/security/ssl.py new file mode 100644 index 000000000..7f0534c9a --- /dev/null +++ b/util/security/ssl.py @@ -0,0 +1,81 @@ +from fnmatch import fnmatch + +import OpenSSL + +class CertInvalidException(Exception): + """ Exception raised when a certificate could not be parsed/loaded. """ + pass + +class KeyInvalidException(Exception): + """ Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """ + pass + + +def load_certificate(cert_contents): + """ Loads the certificate from the given contents and returns it or raises a CertInvalidException + on failure. + """ + try: + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents) + return SSLCertificate(cert) + except OpenSSL.crypto.Error as ex: + raise CertInvalidException(ex.message[0][2]) + + +_SUBJECT_ALT_NAME = 'subjectAltName' + +class SSLCertificate(object): + """ Helper class for easier working with SSL certificates. """ + def __init__(self, openssl_cert): + self.openssl_cert = openssl_cert + + def validate_private_key(self, private_key_path): + """ Validates that the private key found at the given file path applies to this certificate. + Raises a KeyInvalidException on failure. + """ + context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) + context.use_certificate(self.openssl_cert) + + try: + context.use_privatekey_file(private_key_path) + context.check_privatekey() + except OpenSSL.SSL.Error as ex: + raise KeyInvalidException(ex.message[0][2]) + + def matches_name(self, check_name): + """ Returns true if this SSL certificate matches the given DNS hostname. """ + for dns_name in self.names: + if fnmatch(dns_name, check_name): + return True + + return False + + @property + def expired(self): + """ Returns whether the SSL certificate has expired. """ + return self.openssl_cert.has_expired() + + @property + def common_name(self): + """ Returns the defined common name for the certificate, if any. """ + return self.openssl_cert.get_subject().commonName + + @property + def names(self): + """ Returns all the DNS named to which the certificate applies. May be empty. """ + dns_names = set() + common_name = self.common_name + if common_name is not None: + dns_names.add(common_name) + + # Find the DNS extension, if any. + for i in range(0, self.openssl_cert.get_extension_count()): + ext = self.openssl_cert.get_extension(i) + if ext.get_short_name() == _SUBJECT_ALT_NAME: + value = str(ext) + for san_name in value.split(','): + san_name_trimmed = san_name.strip() + if san_name_trimmed.startswith('DNS:'): + dns_names.add(san_name_trimmed[4:]) + + return dns_names