diff --git a/test/test_ssl_util.py b/test/test_ssl_util.py new file mode 100644 index 000000000..69bc0a437 --- /dev/null +++ b/test/test_ssl_util.py @@ -0,0 +1,105 @@ +import unittest + +from tempfile import NamedTemporaryFile +from OpenSSL import crypto + +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException + +class TestSSLCertificate(unittest.TestCase): + def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000): + # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/ + # Create a key pair. + k = crypto.PKey() + k.generate_key(crypto.TYPE_RSA, 1024) + + # Create a self-signed cert. + cert = crypto.X509() + cert.get_subject().CN = hostname + + # Add the subjectAltNames (if necessary). + if san_list is not None: + cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))]) + + cert.set_serial_number(1000) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(expires) + cert.set_issuer(cert.get_subject()) + + cert.set_pubkey(k) + cert.sign(k, 'sha1') + + # Dump the certificate and private key in PEM format. + cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) + key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) + + return (cert_data, key_data) + + def test_load_certificate(self): + # Try loading an invalid certificate. + with self.assertRaisesRegexp(CertInvalidException, 'no start line'): + load_certificate('someinvalidcontents') + + # Load a valid certificate. + (public_key_data, _) = self._generate_cert() + + cert = load_certificate(public_key_data) + self.assertFalse(cert.expired) + self.assertEquals(set(['somehostname']), cert.names) + self.assertTrue(cert.matches_name('somehostname')) + + def test_expired_certificate(self): + (public_key_data, _) = self._generate_cert(expires=-100) + + cert = load_certificate(public_key_data) + self.assertTrue(cert.expired) + + def test_hostnames(self): + (public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz']) + cert = load_certificate(public_key_data) + self.assertEquals(set(['foo', 'bar', 'baz']), cert.names) + + for name in cert.names: + self.assertTrue(cert.matches_name(name)) + + def test_nondns_hostnames(self): + (public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg']) + cert = load_certificate(public_key_data) + self.assertEquals(set(['foo']), cert.names) + + def test_validate_private_key(self): + (public_key_data, private_key_data) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + cert.validate_private_key(private_key.name) + + def test_invalid_private_key(self): + (public_key_data, _) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write('somerandomdata') + private_key.seek(0) + + cert = load_certificate(public_key_data) + with self.assertRaisesRegexp(KeyInvalidException, 'no start line'): + cert.validate_private_key(private_key.name) + + def test_mismatch_private_key(self): + (public_key_data, _) = self._generate_cert() + (_, private_key_data) = self._generate_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'): + cert.validate_private_key(private_key.name) + + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_validate_config.py b/test/test_validate_config.py index 27cca658d..2ad5cf2a8 100644 --- a/test/test_validate_config.py +++ b/test/test_validate_config.py @@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase): 'PREFERRED_URL_SCHEME': 'https', }) - # TODO(jschorr): Add SSL verification tests once file lookup is fixed. - def test_validate_keystone(self): with self.assertRaisesRegexp(ConfigValidationException, 'Verification of superuser someuser failed'): diff --git a/util/config/provider/baseprovider.py b/util/config/provider/baseprovider.py index c833c2323..e5af29fa4 100644 --- a/util/config/provider/baseprovider.py +++ b/util/config/provider/baseprovider.py @@ -88,6 +88,16 @@ class BaseProvider(object): """ Writes the given contents to the config override volumne, with the given filename. """ raise NotImplementedError + def remove_volume_file(self, filename): + """ Removes the config override volume file with the given filename. """ + raise NotImplementedError + + def list_volume_directory(self, path): + """ Returns a list of strings representing the names of the files found in the config override + directory under the given path. If the path doesn't exist, returns None. + """ + raise NotImplementedError + def save_volume_file(self, filename, flask_file): """ Saves the given flask file to the config override volume, with the given filename. diff --git a/util/config/provider/fileprovider.py b/util/config/provider/fileprovider.py index 099c5c3fd..f59c495f1 100644 --- a/util/config/provider/fileprovider.py +++ b/util/config/provider/fileprovider.py @@ -59,6 +59,17 @@ class FileConfigProvider(BaseProvider): return filepath + def remove_volume_file(self, filename): + filepath = os.path.join(self.config_volume, filename) + os.remove(filepath) + + def list_volume_directory(self, path): + dirpath = os.path.join(self.config_volume, path) + if not os.path.exists(dirpath): + return None + + return os.listdir(dirpath) + def save_volume_file(self, filename, flask_file): filepath = os.path.join(self.config_volume, filename) try: diff --git a/util/config/provider/k8sprovider.py b/util/config/provider/k8sprovider.py index 334cf4044..0feb6ca60 100644 --- a/util/config/provider/k8sprovider.py +++ b/util/config/provider/k8sprovider.py @@ -55,53 +55,56 @@ class KubernetesConfigProvider(FileConfigProvider): except IOError as ioe: raise CannotWriteConfigException(str(ioe)) - def save_volume_file(self, filename, flask_file): - filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file) + def remove_volume_file(self, filename): + super(KubernetesConfigProvider, self).remove_volume_file(filename) try: - with open(filepath, 'r') as f: - self._update_secret_file(filename, f.read()) + self._update_secret_file(filename, None) except IOError as ioe: raise CannotWriteConfigException(str(ioe)) + def save_volume_file(self, filename, flask_file): + filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file) + with open(filepath, 'r') as f: + self.write_volume_file(filename, f.read()) + def _assert_success(self, response): if response.status_code != 200: logger.error('Kubernetes API call failed with response: %s => %s', response.status_code, response.text) raise CannotWriteConfigException('Kubernetes API call failed: %s' % response.text) - def _update_secret_file(self, filename, value): + def _update_secret_file(self, filename, value=None): # Check first that the namespace for Quay Enterprise exists. If it does not, report that # as an error, as it seems to be a common issue. namespace_url = 'namespaces/%s' % (QE_NAMESPACE) response = self._execute_k8s_api('GET', namespace_url) - if response.status_code != 200: + if response.status_code // 100 != 2: msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE raise CannotWriteConfigException(msg) - # Save the secret to the namespace. - secret_data = {} - secret_data[filename] = base64.b64encode(value) - - data = { - "kind": "Secret", - "apiVersion": "v1", - "metadata": { - "name": QE_CONFIG_SECRET - }, - "data": secret_data - } - + # Check if the secret exists. If not, then we create an empty secret and then update the file + # inside. secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET) secret = self._lookup_secret() - if not secret: - self._assert_success(self._execute_k8s_api('POST', secret_url, data)) - return + if secret is None: + self._assert_success(self._execute_k8s_api('POST', secret_url, { + "kind": "Secret", + "apiVersion": "v1", + "metadata": { + "name": QE_CONFIG_SECRET + }, + "data": {} + })) - if not 'data' in secret: - secret['data'] = {} + # Update the secret to reflect the file change. + secret['data'] = secret.get('data', {}) + + if value is not None: + secret['data'][filename] = base64.b64encode(value) + else: + secret['data'].pop(filename) - secret['data'][filename] = base64.b64encode(value) self._assert_success(self._execute_k8s_api('PUT', secret_url, secret)) diff --git a/util/config/validator.py b/util/config/validator.py index 5c9155936..60e44b443 100644 --- a/util/config/validator.py +++ b/util/config/validator.py @@ -3,16 +3,14 @@ import subprocess import time from StringIO import StringIO -from fnmatch import fnmatch +from hashlib import sha1 -import OpenSSL import ldap import peewee import redis from flask import Flask from flask_mail import Mail, Message -from hashlib import sha1 from app import app, config_provider, get_app_url, OVERRIDE_CONFIG_DIRECTORY from auth.auth_context import get_authenticated_user @@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC from util.secscan.api import SecurityScannerAPI from util.registry.torrent import torrent_jwt from util.security.signing import SIGNING_ENGINES +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException logger = logging.getLogger(__name__) @@ -54,7 +53,7 @@ def get_storage_providers(config): try: for name, parameters in storage_config.items(): drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters)) - except TypeError as te: + except TypeError: logger.exception('Missing required storage configuration provider') raise ConfigValidationException('Missing required parameter(s) for storage %s' % name) @@ -254,25 +253,35 @@ def _validate_ssl(config, user_obj, _): return # Skip if externally terminated. - if config.get('EXTERNAL_TLS_TERMINATION', False) == True: + if config.get('EXTERNAL_TLS_TERMINATION', False) is True: return + # Verify that we have all the required SSL files. for filename in SSL_FILENAMES: if not config_provider.volume_file_exists(filename): raise ConfigValidationException('Missing required SSL file: %s' % filename) + # Read the contents of the SSL certificate. with config_provider.get_volume_file(SSL_FILENAMES[0]) as f: cert_contents = f.read() # Validate the certificate. try: - cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents) - except: - raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?') + certificate = load_certificate(cert_contents) + except CertInvalidException as cie: + raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message) - if cert.has_expired(): + # Verify the certificate has not expired. + if certificate.expired: raise ConfigValidationException('The specified SSL certificate has expired.') + # Verify the hostname matches the name in the certificate. + if not certificate.matches_name(config['SERVER_HOSTNAME']): + msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % + (', '.join(list(certificate.names)), config['SERVER_HOSTNAME'])) + raise ConfigValidationException(msg) + + # Verify the private key against the certificate. private_key_path = None with config_provider.get_volume_file(SSL_FILENAMES[1]) as f: private_key_path = f.name @@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _): # Only in testing. return - # Validate the private key with the certificate. - context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) - context.use_certificate(cert) - try: - context.use_privatekey_file(private_key_path) - except: - raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?') - - try: - context.check_privatekey() - except OpenSSL.SSL.Error as e: - raise ConfigValidationException('SSL key failed to validate: %s' % str(e)) - - # Verify the hostname matches the name in the certificate. - common_name = cert.get_subject().commonName - if common_name is None: - raise ConfigValidationException('Missing CommonName (CN) from SSL certificate') - - # Build the list of allowed host patterns. - hosts = set([common_name]) - - # Find the DNS extension, if any. - for i in range(0, cert.get_extension_count()): - ext = cert.get_extension(i) - if ext.get_short_name() == 'subjectAltName': - value = str(ext) - hosts.update([host.strip()[4:] for host in value.split(',')]) - - # Check each host. - for host in hosts: - if fnmatch(config['SERVER_HOSTNAME'], host): - return - - msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' % - (', '.join(list(hosts)), config['SERVER_HOSTNAME'])) - raise ConfigValidationException(msg) - + certificate.validate_private_key(private_key_path) + except KeyInvalidException as kie: + raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message) def _validate_ldap(config, user_obj, password): diff --git a/util/security/ssl.py b/util/security/ssl.py new file mode 100644 index 000000000..7f0534c9a --- /dev/null +++ b/util/security/ssl.py @@ -0,0 +1,81 @@ +from fnmatch import fnmatch + +import OpenSSL + +class CertInvalidException(Exception): + """ Exception raised when a certificate could not be parsed/loaded. """ + pass + +class KeyInvalidException(Exception): + """ Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """ + pass + + +def load_certificate(cert_contents): + """ Loads the certificate from the given contents and returns it or raises a CertInvalidException + on failure. + """ + try: + cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents) + return SSLCertificate(cert) + except OpenSSL.crypto.Error as ex: + raise CertInvalidException(ex.message[0][2]) + + +_SUBJECT_ALT_NAME = 'subjectAltName' + +class SSLCertificate(object): + """ Helper class for easier working with SSL certificates. """ + def __init__(self, openssl_cert): + self.openssl_cert = openssl_cert + + def validate_private_key(self, private_key_path): + """ Validates that the private key found at the given file path applies to this certificate. + Raises a KeyInvalidException on failure. + """ + context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD) + context.use_certificate(self.openssl_cert) + + try: + context.use_privatekey_file(private_key_path) + context.check_privatekey() + except OpenSSL.SSL.Error as ex: + raise KeyInvalidException(ex.message[0][2]) + + def matches_name(self, check_name): + """ Returns true if this SSL certificate matches the given DNS hostname. """ + for dns_name in self.names: + if fnmatch(dns_name, check_name): + return True + + return False + + @property + def expired(self): + """ Returns whether the SSL certificate has expired. """ + return self.openssl_cert.has_expired() + + @property + def common_name(self): + """ Returns the defined common name for the certificate, if any. """ + return self.openssl_cert.get_subject().commonName + + @property + def names(self): + """ Returns all the DNS named to which the certificate applies. May be empty. """ + dns_names = set() + common_name = self.common_name + if common_name is not None: + dns_names.add(common_name) + + # Find the DNS extension, if any. + for i in range(0, self.openssl_cert.get_extension_count()): + ext = self.openssl_cert.get_extension(i) + if ext.get_short_name() == _SUBJECT_ALT_NAME: + value = str(ext) + for san_name in value.split(','): + san_name_trimmed = san_name.strip() + if san_name_trimmed.startswith('DNS:'): + dns_names.add(san_name_trimmed[4:]) + + return dns_names