Merge pull request #2271 from coreos-inc/custom-certs
Better handling and testing of custom certificates
This commit is contained in:
commit
a6ae770b77
7 changed files with 256 additions and 73 deletions
105
test/test_ssl_util.py
Normal file
105
test/test_ssl_util.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
import unittest
|
||||
|
||||
from tempfile import NamedTemporaryFile
|
||||
from OpenSSL import crypto
|
||||
|
||||
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
||||
|
||||
class TestSSLCertificate(unittest.TestCase):
|
||||
def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000):
|
||||
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
|
||||
# Create a key pair.
|
||||
k = crypto.PKey()
|
||||
k.generate_key(crypto.TYPE_RSA, 1024)
|
||||
|
||||
# Create a self-signed cert.
|
||||
cert = crypto.X509()
|
||||
cert.get_subject().CN = hostname
|
||||
|
||||
# Add the subjectAltNames (if necessary).
|
||||
if san_list is not None:
|
||||
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
|
||||
|
||||
cert.set_serial_number(1000)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(expires)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
|
||||
cert.set_pubkey(k)
|
||||
cert.sign(k, 'sha1')
|
||||
|
||||
# Dump the certificate and private key in PEM format.
|
||||
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
|
||||
|
||||
return (cert_data, key_data)
|
||||
|
||||
def test_load_certificate(self):
|
||||
# Try loading an invalid certificate.
|
||||
with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
|
||||
load_certificate('someinvalidcontents')
|
||||
|
||||
# Load a valid certificate.
|
||||
(public_key_data, _) = self._generate_cert()
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
self.assertFalse(cert.expired)
|
||||
self.assertEquals(set(['somehostname']), cert.names)
|
||||
self.assertTrue(cert.matches_name('somehostname'))
|
||||
|
||||
def test_expired_certificate(self):
|
||||
(public_key_data, _) = self._generate_cert(expires=-100)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
self.assertTrue(cert.expired)
|
||||
|
||||
def test_hostnames(self):
|
||||
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
|
||||
cert = load_certificate(public_key_data)
|
||||
self.assertEquals(set(['foo', 'bar', 'baz']), cert.names)
|
||||
|
||||
for name in cert.names:
|
||||
self.assertTrue(cert.matches_name(name))
|
||||
|
||||
def test_nondns_hostnames(self):
|
||||
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg'])
|
||||
cert = load_certificate(public_key_data)
|
||||
self.assertEquals(set(['foo']), cert.names)
|
||||
|
||||
def test_validate_private_key(self):
|
||||
(public_key_data, private_key_data) = self._generate_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write(private_key_data)
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
cert.validate_private_key(private_key.name)
|
||||
|
||||
def test_invalid_private_key(self):
|
||||
(public_key_data, _) = self._generate_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write('somerandomdata')
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
|
||||
cert.validate_private_key(private_key.name)
|
||||
|
||||
def test_mismatch_private_key(self):
|
||||
(public_key_data, _) = self._generate_cert()
|
||||
(_, private_key_data) = self._generate_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write(private_key_data)
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
|
||||
cert.validate_private_key(private_key.name)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
|
@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase):
|
|||
'PREFERRED_URL_SCHEME': 'https',
|
||||
})
|
||||
|
||||
# TODO(jschorr): Add SSL verification tests once file lookup is fixed.
|
||||
|
||||
def test_validate_keystone(self):
|
||||
with self.assertRaisesRegexp(ConfigValidationException,
|
||||
'Verification of superuser someuser failed'):
|
||||
|
|
|
@ -88,6 +88,16 @@ class BaseProvider(object):
|
|||
""" Writes the given contents to the config override volumne, with the given filename. """
|
||||
raise NotImplementedError
|
||||
|
||||
def remove_volume_file(self, filename):
|
||||
""" Removes the config override volume file with the given filename. """
|
||||
raise NotImplementedError
|
||||
|
||||
def list_volume_directory(self, path):
|
||||
""" Returns a list of strings representing the names of the files found in the config override
|
||||
directory under the given path. If the path doesn't exist, returns None.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_volume_file(self, filename, flask_file):
|
||||
""" Saves the given flask file to the config override volume, with the given
|
||||
filename.
|
||||
|
|
|
@ -59,6 +59,17 @@ class FileConfigProvider(BaseProvider):
|
|||
|
||||
return filepath
|
||||
|
||||
def remove_volume_file(self, filename):
|
||||
filepath = os.path.join(self.config_volume, filename)
|
||||
os.remove(filepath)
|
||||
|
||||
def list_volume_directory(self, path):
|
||||
dirpath = os.path.join(self.config_volume, path)
|
||||
if not os.path.exists(dirpath):
|
||||
return None
|
||||
|
||||
return os.listdir(dirpath)
|
||||
|
||||
def save_volume_file(self, filename, flask_file):
|
||||
filepath = os.path.join(self.config_volume, filename)
|
||||
try:
|
||||
|
|
|
@ -55,53 +55,56 @@ class KubernetesConfigProvider(FileConfigProvider):
|
|||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
def save_volume_file(self, filename, flask_file):
|
||||
filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file)
|
||||
def remove_volume_file(self, filename):
|
||||
super(KubernetesConfigProvider, self).remove_volume_file(filename)
|
||||
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
self._update_secret_file(filename, f.read())
|
||||
self._update_secret_file(filename, None)
|
||||
except IOError as ioe:
|
||||
raise CannotWriteConfigException(str(ioe))
|
||||
|
||||
def save_volume_file(self, filename, flask_file):
|
||||
filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file)
|
||||
with open(filepath, 'r') as f:
|
||||
self.write_volume_file(filename, f.read())
|
||||
|
||||
def _assert_success(self, response):
|
||||
if response.status_code != 200:
|
||||
logger.error('Kubernetes API call failed with response: %s => %s', response.status_code,
|
||||
response.text)
|
||||
raise CannotWriteConfigException('Kubernetes API call failed: %s' % response.text)
|
||||
|
||||
def _update_secret_file(self, filename, value):
|
||||
def _update_secret_file(self, filename, value=None):
|
||||
# Check first that the namespace for Quay Enterprise exists. If it does not, report that
|
||||
# as an error, as it seems to be a common issue.
|
||||
namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
|
||||
response = self._execute_k8s_api('GET', namespace_url)
|
||||
if response.status_code != 200:
|
||||
if response.status_code // 100 != 2:
|
||||
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
|
||||
raise CannotWriteConfigException(msg)
|
||||
|
||||
# Save the secret to the namespace.
|
||||
secret_data = {}
|
||||
secret_data[filename] = base64.b64encode(value)
|
||||
|
||||
data = {
|
||||
"kind": "Secret",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {
|
||||
"name": QE_CONFIG_SECRET
|
||||
},
|
||||
"data": secret_data
|
||||
}
|
||||
|
||||
# Check if the secret exists. If not, then we create an empty secret and then update the file
|
||||
# inside.
|
||||
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
|
||||
secret = self._lookup_secret()
|
||||
if not secret:
|
||||
self._assert_success(self._execute_k8s_api('POST', secret_url, data))
|
||||
return
|
||||
if secret is None:
|
||||
self._assert_success(self._execute_k8s_api('POST', secret_url, {
|
||||
"kind": "Secret",
|
||||
"apiVersion": "v1",
|
||||
"metadata": {
|
||||
"name": QE_CONFIG_SECRET
|
||||
},
|
||||
"data": {}
|
||||
}))
|
||||
|
||||
if not 'data' in secret:
|
||||
secret['data'] = {}
|
||||
# Update the secret to reflect the file change.
|
||||
secret['data'] = secret.get('data', {})
|
||||
|
||||
if value is not None:
|
||||
secret['data'][filename] = base64.b64encode(value)
|
||||
else:
|
||||
secret['data'].pop(filename)
|
||||
|
||||
secret['data'][filename] = base64.b64encode(value)
|
||||
self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))
|
||||
|
||||
|
||||
|
|
|
@ -3,16 +3,14 @@ import subprocess
|
|||
import time
|
||||
|
||||
from StringIO import StringIO
|
||||
from fnmatch import fnmatch
|
||||
from hashlib import sha1
|
||||
|
||||
import OpenSSL
|
||||
import ldap
|
||||
import peewee
|
||||
import redis
|
||||
|
||||
from flask import Flask
|
||||
from flask_mail import Mail, Message
|
||||
from hashlib import sha1
|
||||
|
||||
from app import app, config_provider, get_app_url, OVERRIDE_CONFIG_DIRECTORY
|
||||
from auth.auth_context import get_authenticated_user
|
||||
|
@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC
|
|||
from util.secscan.api import SecurityScannerAPI
|
||||
from util.registry.torrent import torrent_jwt
|
||||
from util.security.signing import SIGNING_ENGINES
|
||||
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -54,7 +53,7 @@ def get_storage_providers(config):
|
|||
try:
|
||||
for name, parameters in storage_config.items():
|
||||
drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters))
|
||||
except TypeError as te:
|
||||
except TypeError:
|
||||
logger.exception('Missing required storage configuration provider')
|
||||
raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)
|
||||
|
||||
|
@ -254,25 +253,35 @@ def _validate_ssl(config, user_obj, _):
|
|||
return
|
||||
|
||||
# Skip if externally terminated.
|
||||
if config.get('EXTERNAL_TLS_TERMINATION', False) == True:
|
||||
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
|
||||
return
|
||||
|
||||
# Verify that we have all the required SSL files.
|
||||
for filename in SSL_FILENAMES:
|
||||
if not config_provider.volume_file_exists(filename):
|
||||
raise ConfigValidationException('Missing required SSL file: %s' % filename)
|
||||
|
||||
# Read the contents of the SSL certificate.
|
||||
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
|
||||
cert_contents = f.read()
|
||||
|
||||
# Validate the certificate.
|
||||
try:
|
||||
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
|
||||
except:
|
||||
raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?')
|
||||
certificate = load_certificate(cert_contents)
|
||||
except CertInvalidException as cie:
|
||||
raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message)
|
||||
|
||||
if cert.has_expired():
|
||||
# Verify the certificate has not expired.
|
||||
if certificate.expired:
|
||||
raise ConfigValidationException('The specified SSL certificate has expired.')
|
||||
|
||||
# Verify the hostname matches the name in the certificate.
|
||||
if not certificate.matches_name(config['SERVER_HOSTNAME']):
|
||||
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
|
||||
(', '.join(list(certificate.names)), config['SERVER_HOSTNAME']))
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
# Verify the private key against the certificate.
|
||||
private_key_path = None
|
||||
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
|
||||
private_key_path = f.name
|
||||
|
@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _):
|
|||
# Only in testing.
|
||||
return
|
||||
|
||||
# Validate the private key with the certificate.
|
||||
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
|
||||
context.use_certificate(cert)
|
||||
|
||||
try:
|
||||
context.use_privatekey_file(private_key_path)
|
||||
except:
|
||||
raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?')
|
||||
|
||||
try:
|
||||
context.check_privatekey()
|
||||
except OpenSSL.SSL.Error as e:
|
||||
raise ConfigValidationException('SSL key failed to validate: %s' % str(e))
|
||||
|
||||
# Verify the hostname matches the name in the certificate.
|
||||
common_name = cert.get_subject().commonName
|
||||
if common_name is None:
|
||||
raise ConfigValidationException('Missing CommonName (CN) from SSL certificate')
|
||||
|
||||
# Build the list of allowed host patterns.
|
||||
hosts = set([common_name])
|
||||
|
||||
# Find the DNS extension, if any.
|
||||
for i in range(0, cert.get_extension_count()):
|
||||
ext = cert.get_extension(i)
|
||||
if ext.get_short_name() == 'subjectAltName':
|
||||
value = str(ext)
|
||||
hosts.update([host.strip()[4:] for host in value.split(',')])
|
||||
|
||||
# Check each host.
|
||||
for host in hosts:
|
||||
if fnmatch(config['SERVER_HOSTNAME'], host):
|
||||
return
|
||||
|
||||
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
|
||||
(', '.join(list(hosts)), config['SERVER_HOSTNAME']))
|
||||
raise ConfigValidationException(msg)
|
||||
|
||||
certificate.validate_private_key(private_key_path)
|
||||
except KeyInvalidException as kie:
|
||||
raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)
|
||||
|
||||
|
||||
def _validate_ldap(config, user_obj, password):
|
||||
|
|
81
util/security/ssl.py
Normal file
81
util/security/ssl.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
from fnmatch import fnmatch
|
||||
|
||||
import OpenSSL
|
||||
|
||||
class CertInvalidException(Exception):
|
||||
""" Exception raised when a certificate could not be parsed/loaded. """
|
||||
pass
|
||||
|
||||
class KeyInvalidException(Exception):
|
||||
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
|
||||
pass
|
||||
|
||||
|
||||
def load_certificate(cert_contents):
|
||||
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
|
||||
on failure.
|
||||
"""
|
||||
try:
|
||||
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
|
||||
return SSLCertificate(cert)
|
||||
except OpenSSL.crypto.Error as ex:
|
||||
raise CertInvalidException(ex.message[0][2])
|
||||
|
||||
|
||||
_SUBJECT_ALT_NAME = 'subjectAltName'
|
||||
|
||||
class SSLCertificate(object):
|
||||
""" Helper class for easier working with SSL certificates. """
|
||||
def __init__(self, openssl_cert):
|
||||
self.openssl_cert = openssl_cert
|
||||
|
||||
def validate_private_key(self, private_key_path):
|
||||
""" Validates that the private key found at the given file path applies to this certificate.
|
||||
Raises a KeyInvalidException on failure.
|
||||
"""
|
||||
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
|
||||
context.use_certificate(self.openssl_cert)
|
||||
|
||||
try:
|
||||
context.use_privatekey_file(private_key_path)
|
||||
context.check_privatekey()
|
||||
except OpenSSL.SSL.Error as ex:
|
||||
raise KeyInvalidException(ex.message[0][2])
|
||||
|
||||
def matches_name(self, check_name):
|
||||
""" Returns true if this SSL certificate matches the given DNS hostname. """
|
||||
for dns_name in self.names:
|
||||
if fnmatch(dns_name, check_name):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
""" Returns whether the SSL certificate has expired. """
|
||||
return self.openssl_cert.has_expired()
|
||||
|
||||
@property
|
||||
def common_name(self):
|
||||
""" Returns the defined common name for the certificate, if any. """
|
||||
return self.openssl_cert.get_subject().commonName
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
""" Returns all the DNS named to which the certificate applies. May be empty. """
|
||||
dns_names = set()
|
||||
common_name = self.common_name
|
||||
if common_name is not None:
|
||||
dns_names.add(common_name)
|
||||
|
||||
# Find the DNS extension, if any.
|
||||
for i in range(0, self.openssl_cert.get_extension_count()):
|
||||
ext = self.openssl_cert.get_extension(i)
|
||||
if ext.get_short_name() == _SUBJECT_ALT_NAME:
|
||||
value = str(ext)
|
||||
for san_name in value.split(','):
|
||||
san_name_trimmed = san_name.strip()
|
||||
if san_name_trimmed.startswith('DNS:'):
|
||||
dns_names.add(san_name_trimmed[4:])
|
||||
|
||||
return dns_names
|
Reference in a new issue