Add SSL certificate utility and tests

This commit is contained in:
Joseph Schorr 2016-12-09 18:31:02 -05:00
parent f1c9965edf
commit 3a24871422
5 changed files with 205 additions and 46 deletions

105
test/test_ssl_util.py Normal file
View file

@ -0,0 +1,105 @@
import unittest
from tempfile import NamedTemporaryFile
from OpenSSL import crypto
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
class TestSSLCertificate(unittest.TestCase):
def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000):
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
# Create a key pair.
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 1024)
# Create a self-signed cert.
cert = crypto.X509()
cert.get_subject().CN = hostname
# Add the subjectAltNames (if necessary).
if san_list is not None:
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(expires)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.sign(k, 'sha1')
# Dump the certificate and private key in PEM format.
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
return (cert_data, key_data)
def test_load_certificate(self):
# Try loading an invalid certificate.
with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
load_certificate('someinvalidcontents')
# Load a valid certificate.
(public_key_data, _) = self._generate_cert()
cert = load_certificate(public_key_data)
self.assertFalse(cert.expired)
self.assertEquals(set(['somehostname']), cert.names)
self.assertTrue(cert.matches_name('somehostname'))
def test_expired_certificate(self):
(public_key_data, _) = self._generate_cert(expires=-100)
cert = load_certificate(public_key_data)
self.assertTrue(cert.expired)
def test_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo', 'bar', 'baz']), cert.names)
for name in cert.names:
self.assertTrue(cert.matches_name(name))
def test_nondns_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo']), cert.names)
def test_validate_private_key(self):
(public_key_data, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
cert.validate_private_key(private_key.name)
def test_invalid_private_key(self):
(public_key_data, _) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write('somerandomdata')
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
cert.validate_private_key(private_key.name)
def test_mismatch_private_key(self):
(public_key_data, _) = self._generate_cert()
(_, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
cert.validate_private_key(private_key.name)
if __name__ == '__main__':
unittest.main()

View file

@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase):
'PREFERRED_URL_SCHEME': 'https',
})
# TODO(jschorr): Add SSL verification tests once file lookup is fixed.
def test_validate_keystone(self):
with self.assertRaisesRegexp(ConfigValidationException,
'Verification of superuser someuser failed'):

View file

@ -79,7 +79,7 @@ class KubernetesConfigProvider(FileConfigProvider):
# as an error, as it seems to be a common issue.
namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
response = self._execute_k8s_api('GET', namespace_url)
if response.status_code / 100 != 2:
if response.status_code // 100 != 2:
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
raise CannotWriteConfigException(msg)

View file

@ -3,10 +3,8 @@ import subprocess
import time
from StringIO import StringIO
from fnmatch import fnmatch
from hashlib import sha1
import OpenSSL
import ldap
import peewee
import redis
@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC
from util.secscan.api import SecurityScannerAPI
from util.registry.torrent import torrent_jwt
from util.security.signing import SIGNING_ENGINES
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
logger = logging.getLogger(__name__)
@ -257,22 +256,32 @@ def _validate_ssl(config, user_obj, _):
if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
return
# Verify that we have all the required SSL files.
for filename in SSL_FILENAMES:
if not config_provider.volume_file_exists(filename):
raise ConfigValidationException('Missing required SSL file: %s' % filename)
# Read the contents of the SSL certificate.
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
cert_contents = f.read()
# Validate the certificate.
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
except:
raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?')
certificate = load_certificate(cert_contents)
except CertInvalidException as cie:
raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message)
if cert.has_expired():
# Verify the certificate has not expired.
if certificate.expired:
raise ConfigValidationException('The specified SSL certificate has expired.')
# Verify the hostname matches the name in the certificate.
if not certificate.matches_name(config['SERVER_HOSTNAME']):
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(certificate.names)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
# Verify the private key against the certificate.
private_key_path = None
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
private_key_path = f.name
@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _):
# Only in testing.
return
# Validate the private key with the certificate.
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(cert)
try:
context.use_privatekey_file(private_key_path)
except:
raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?')
try:
context.check_privatekey()
except OpenSSL.SSL.Error as e:
raise ConfigValidationException('SSL key failed to validate: %s' % str(e))
# Verify the hostname matches the name in the certificate.
common_name = cert.get_subject().commonName
if common_name is None:
raise ConfigValidationException('Missing CommonName (CN) from SSL certificate')
# Build the list of allowed host patterns.
hosts = set([common_name])
# Find the DNS extension, if any.
for i in range(0, cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == 'subjectAltName':
value = str(ext)
hosts.update([host.strip()[4:] for host in value.split(',')])
# Check each host.
for host in hosts:
if fnmatch(config['SERVER_HOSTNAME'], host):
return
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(hosts)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
certificate.validate_private_key(private_key_path)
except KeyInvalidException as kie:
raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)
def _validate_ldap(config, user_obj, password):

81
util/security/ssl.py Normal file
View file

@ -0,0 +1,81 @@
from fnmatch import fnmatch
import OpenSSL
class CertInvalidException(Exception):
""" Exception raised when a certificate could not be parsed/loaded. """
pass
class KeyInvalidException(Exception):
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
pass
def load_certificate(cert_contents):
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
on failure.
"""
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
return SSLCertificate(cert)
except OpenSSL.crypto.Error as ex:
raise CertInvalidException(ex.message[0][2])
_SUBJECT_ALT_NAME = 'subjectAltName'
class SSLCertificate(object):
""" Helper class for easier working with SSL certificates. """
def __init__(self, openssl_cert):
self.openssl_cert = openssl_cert
def validate_private_key(self, private_key_path):
""" Validates that the private key found at the given file path applies to this certificate.
Raises a KeyInvalidException on failure.
"""
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(self.openssl_cert)
try:
context.use_privatekey_file(private_key_path)
context.check_privatekey()
except OpenSSL.SSL.Error as ex:
raise KeyInvalidException(ex.message[0][2])
def matches_name(self, check_name):
""" Returns true if this SSL certificate matches the given DNS hostname. """
for dns_name in self.names:
if fnmatch(dns_name, check_name):
return True
return False
@property
def expired(self):
""" Returns whether the SSL certificate has expired. """
return self.openssl_cert.has_expired()
@property
def common_name(self):
""" Returns the defined common name for the certificate, if any. """
return self.openssl_cert.get_subject().commonName
@property
def names(self):
""" Returns all the DNS named to which the certificate applies. May be empty. """
dns_names = set()
common_name = self.common_name
if common_name is not None:
dns_names.add(common_name)
# Find the DNS extension, if any.
for i in range(0, self.openssl_cert.get_extension_count()):
ext = self.openssl_cert.get_extension(i)
if ext.get_short_name() == _SUBJECT_ALT_NAME:
value = str(ext)
for san_name in value.split(','):
san_name_trimmed = san_name.strip()
if san_name_trimmed.startswith('DNS:'):
dns_names.add(san_name_trimmed[4:])
return dns_names