Merge pull request #2271 from coreos-inc/custom-certs

Better handling and testing of custom certificates
This commit is contained in:
josephschorr 2017-01-10 18:13:44 -05:00 committed by GitHub
commit a6ae770b77
7 changed files with 256 additions and 73 deletions

105
test/test_ssl_util.py Normal file
View file

@ -0,0 +1,105 @@
import unittest
from tempfile import NamedTemporaryFile
from OpenSSL import crypto
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
class TestSSLCertificate(unittest.TestCase):
def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000):
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
# Create a key pair.
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 1024)
# Create a self-signed cert.
cert = crypto.X509()
cert.get_subject().CN = hostname
# Add the subjectAltNames (if necessary).
if san_list is not None:
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(expires)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.sign(k, 'sha1')
# Dump the certificate and private key in PEM format.
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
return (cert_data, key_data)
def test_load_certificate(self):
# Try loading an invalid certificate.
with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
load_certificate('someinvalidcontents')
# Load a valid certificate.
(public_key_data, _) = self._generate_cert()
cert = load_certificate(public_key_data)
self.assertFalse(cert.expired)
self.assertEquals(set(['somehostname']), cert.names)
self.assertTrue(cert.matches_name('somehostname'))
def test_expired_certificate(self):
(public_key_data, _) = self._generate_cert(expires=-100)
cert = load_certificate(public_key_data)
self.assertTrue(cert.expired)
def test_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo', 'bar', 'baz']), cert.names)
for name in cert.names:
self.assertTrue(cert.matches_name(name))
def test_nondns_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo']), cert.names)
def test_validate_private_key(self):
(public_key_data, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
cert.validate_private_key(private_key.name)
def test_invalid_private_key(self):
(public_key_data, _) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write('somerandomdata')
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
cert.validate_private_key(private_key.name)
def test_mismatch_private_key(self):
(public_key_data, _) = self._generate_cert()
(_, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
cert.validate_private_key(private_key.name)
if __name__ == '__main__':
unittest.main()

View file

@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase):
'PREFERRED_URL_SCHEME': 'https',
})
# TODO(jschorr): Add SSL verification tests once file lookup is fixed.
def test_validate_keystone(self):
with self.assertRaisesRegexp(ConfigValidationException,
'Verification of superuser someuser failed'):

View file

@ -88,6 +88,16 @@ class BaseProvider(object):
""" Writes the given contents to the config override volumne, with the given filename. """
raise NotImplementedError
def remove_volume_file(self, filename):
""" Removes the config override volume file with the given filename. """
raise NotImplementedError
def list_volume_directory(self, path):
""" Returns a list of strings representing the names of the files found in the config override
directory under the given path. If the path doesn't exist, returns None.
"""
raise NotImplementedError
def save_volume_file(self, filename, flask_file):
""" Saves the given flask file to the config override volume, with the given
filename.

View file

@ -59,6 +59,17 @@ class FileConfigProvider(BaseProvider):
return filepath
def remove_volume_file(self, filename):
filepath = os.path.join(self.config_volume, filename)
os.remove(filepath)
def list_volume_directory(self, path):
dirpath = os.path.join(self.config_volume, path)
if not os.path.exists(dirpath):
return None
return os.listdir(dirpath)
def save_volume_file(self, filename, flask_file):
filepath = os.path.join(self.config_volume, filename)
try:

View file

@ -55,53 +55,56 @@ class KubernetesConfigProvider(FileConfigProvider):
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
def save_volume_file(self, filename, flask_file):
filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file)
def remove_volume_file(self, filename):
super(KubernetesConfigProvider, self).remove_volume_file(filename)
try:
with open(filepath, 'r') as f:
self._update_secret_file(filename, f.read())
self._update_secret_file(filename, None)
except IOError as ioe:
raise CannotWriteConfigException(str(ioe))
def save_volume_file(self, filename, flask_file):
filepath = super(KubernetesConfigProvider, self).save_volume_file(filename, flask_file)
with open(filepath, 'r') as f:
self.write_volume_file(filename, f.read())
def _assert_success(self, response):
if response.status_code != 200:
logger.error('Kubernetes API call failed with response: %s => %s', response.status_code,
response.text)
raise CannotWriteConfigException('Kubernetes API call failed: %s' % response.text)
def _update_secret_file(self, filename, value):
def _update_secret_file(self, filename, value=None):
# Check first that the namespace for Quay Enterprise exists. If it does not, report that
# as an error, as it seems to be a common issue.
namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
response = self._execute_k8s_api('GET', namespace_url)
if response.status_code != 200:
if response.status_code // 100 != 2:
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
raise CannotWriteConfigException(msg)
# Save the secret to the namespace.
secret_data = {}
secret_data[filename] = base64.b64encode(value)
data = {
"kind": "Secret",
"apiVersion": "v1",
"metadata": {
"name": QE_CONFIG_SECRET
},
"data": secret_data
}
# Check if the secret exists. If not, then we create an empty secret and then update the file
# inside.
secret_url = 'namespaces/%s/secrets/%s' % (QE_NAMESPACE, QE_CONFIG_SECRET)
secret = self._lookup_secret()
if not secret:
self._assert_success(self._execute_k8s_api('POST', secret_url, data))
return
if secret is None:
self._assert_success(self._execute_k8s_api('POST', secret_url, {
"kind": "Secret",
"apiVersion": "v1",
"metadata": {
"name": QE_CONFIG_SECRET
},
"data": {}
}))
if not 'data' in secret:
secret['data'] = {}
# Update the secret to reflect the file change.
secret['data'] = secret.get('data', {})
if value is not None:
secret['data'][filename] = base64.b64encode(value)
else:
secret['data'].pop(filename)
secret['data'][filename] = base64.b64encode(value)
self._assert_success(self._execute_k8s_api('PUT', secret_url, secret))

View file

@ -3,16 +3,14 @@ import subprocess
import time
from StringIO import StringIO
from fnmatch import fnmatch
from hashlib import sha1
import OpenSSL
import ldap
import peewee
import redis
from flask import Flask
from flask_mail import Mail, Message
from hashlib import sha1
from app import app, config_provider, get_app_url, OVERRIDE_CONFIG_DIRECTORY
from auth.auth_context import get_authenticated_user
@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC
from util.secscan.api import SecurityScannerAPI
from util.registry.torrent import torrent_jwt
from util.security.signing import SIGNING_ENGINES
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
logger = logging.getLogger(__name__)
@ -54,7 +53,7 @@ def get_storage_providers(config):
try:
for name, parameters in storage_config.items():
drivers[name] = (parameters[0], get_storage_driver(None, None, None, parameters))
except TypeError as te:
except TypeError:
logger.exception('Missing required storage configuration provider')
raise ConfigValidationException('Missing required parameter(s) for storage %s' % name)
@ -254,25 +253,35 @@ def _validate_ssl(config, user_obj, _):
return
# Skip if externally terminated.
if config.get('EXTERNAL_TLS_TERMINATION', False) == True:
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
return
# Verify that we have all the required SSL files.
for filename in SSL_FILENAMES:
if not config_provider.volume_file_exists(filename):
raise ConfigValidationException('Missing required SSL file: %s' % filename)
# Read the contents of the SSL certificate.
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
cert_contents = f.read()
# Validate the certificate.
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
except:
raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?')
certificate = load_certificate(cert_contents)
except CertInvalidException as cie:
raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message)
if cert.has_expired():
# Verify the certificate has not expired.
if certificate.expired:
raise ConfigValidationException('The specified SSL certificate has expired.')
# Verify the hostname matches the name in the certificate.
if not certificate.matches_name(config['SERVER_HOSTNAME']):
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(certificate.names)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
# Verify the private key against the certificate.
private_key_path = None
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
private_key_path = f.name
@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _):
# Only in testing.
return
# Validate the private key with the certificate.
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(cert)
try:
context.use_privatekey_file(private_key_path)
except:
raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?')
try:
context.check_privatekey()
except OpenSSL.SSL.Error as e:
raise ConfigValidationException('SSL key failed to validate: %s' % str(e))
# Verify the hostname matches the name in the certificate.
common_name = cert.get_subject().commonName
if common_name is None:
raise ConfigValidationException('Missing CommonName (CN) from SSL certificate')
# Build the list of allowed host patterns.
hosts = set([common_name])
# Find the DNS extension, if any.
for i in range(0, cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == 'subjectAltName':
value = str(ext)
hosts.update([host.strip()[4:] for host in value.split(',')])
# Check each host.
for host in hosts:
if fnmatch(config['SERVER_HOSTNAME'], host):
return
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(hosts)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
certificate.validate_private_key(private_key_path)
except KeyInvalidException as kie:
raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)
def _validate_ldap(config, user_obj, password):

81
util/security/ssl.py Normal file
View file

@ -0,0 +1,81 @@
from fnmatch import fnmatch
import OpenSSL
class CertInvalidException(Exception):
""" Exception raised when a certificate could not be parsed/loaded. """
pass
class KeyInvalidException(Exception):
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
pass
def load_certificate(cert_contents):
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
on failure.
"""
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
return SSLCertificate(cert)
except OpenSSL.crypto.Error as ex:
raise CertInvalidException(ex.message[0][2])
_SUBJECT_ALT_NAME = 'subjectAltName'
class SSLCertificate(object):
""" Helper class for easier working with SSL certificates. """
def __init__(self, openssl_cert):
self.openssl_cert = openssl_cert
def validate_private_key(self, private_key_path):
""" Validates that the private key found at the given file path applies to this certificate.
Raises a KeyInvalidException on failure.
"""
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(self.openssl_cert)
try:
context.use_privatekey_file(private_key_path)
context.check_privatekey()
except OpenSSL.SSL.Error as ex:
raise KeyInvalidException(ex.message[0][2])
def matches_name(self, check_name):
""" Returns true if this SSL certificate matches the given DNS hostname. """
for dns_name in self.names:
if fnmatch(dns_name, check_name):
return True
return False
@property
def expired(self):
""" Returns whether the SSL certificate has expired. """
return self.openssl_cert.has_expired()
@property
def common_name(self):
""" Returns the defined common name for the certificate, if any. """
return self.openssl_cert.get_subject().commonName
@property
def names(self):
""" Returns all the DNS named to which the certificate applies. May be empty. """
dns_names = set()
common_name = self.common_name
if common_name is not None:
dns_names.add(common_name)
# Find the DNS extension, if any.
for i in range(0, self.openssl_cert.get_extension_count()):
ext = self.openssl_cert.get_extension(i)
if ext.get_short_name() == _SUBJECT_ALT_NAME:
value = str(ext)
for san_name in value.split(','):
san_name_trimmed = san_name.strip()
if san_name_trimmed.startswith('DNS:'):
dns_names.add(san_name_trimmed[4:])
return dns_names