Add SSL certificate utility and tests

This commit is contained in:
Joseph Schorr 2016-12-09 18:31:02 -05:00
parent f1c9965edf
commit 3a24871422
5 changed files with 205 additions and 46 deletions

105
test/test_ssl_util.py Normal file
View file

@ -0,0 +1,105 @@
import unittest
from tempfile import NamedTemporaryFile
from OpenSSL import crypto
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
class TestSSLCertificate(unittest.TestCase):
def _generate_cert(self, hostname='somehostname', san_list=None, expires=1000000):
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
# Create a key pair.
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 1024)
# Create a self-signed cert.
cert = crypto.X509()
cert.get_subject().CN = hostname
# Add the subjectAltNames (if necessary).
if san_list is not None:
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(expires)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.sign(k, 'sha1')
# Dump the certificate and private key in PEM format.
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
return (cert_data, key_data)
def test_load_certificate(self):
# Try loading an invalid certificate.
with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
load_certificate('someinvalidcontents')
# Load a valid certificate.
(public_key_data, _) = self._generate_cert()
cert = load_certificate(public_key_data)
self.assertFalse(cert.expired)
self.assertEquals(set(['somehostname']), cert.names)
self.assertTrue(cert.matches_name('somehostname'))
def test_expired_certificate(self):
(public_key_data, _) = self._generate_cert(expires=-100)
cert = load_certificate(public_key_data)
self.assertTrue(cert.expired)
def test_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo', 'bar', 'baz']), cert.names)
for name in cert.names:
self.assertTrue(cert.matches_name(name))
def test_nondns_hostnames(self):
(public_key_data, _) = self._generate_cert(hostname='foo', san_list=['URI:yarg'])
cert = load_certificate(public_key_data)
self.assertEquals(set(['foo']), cert.names)
def test_validate_private_key(self):
(public_key_data, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
cert.validate_private_key(private_key.name)
def test_invalid_private_key(self):
(public_key_data, _) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write('somerandomdata')
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
cert.validate_private_key(private_key.name)
def test_mismatch_private_key(self):
(public_key_data, _) = self._generate_cert()
(_, private_key_data) = self._generate_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
cert.validate_private_key(private_key.name)
if __name__ == '__main__':
unittest.main()

View file

@ -143,8 +143,6 @@ class TestValidateConfig(unittest.TestCase):
'PREFERRED_URL_SCHEME': 'https', 'PREFERRED_URL_SCHEME': 'https',
}) })
# TODO(jschorr): Add SSL verification tests once file lookup is fixed.
def test_validate_keystone(self): def test_validate_keystone(self):
with self.assertRaisesRegexp(ConfigValidationException, with self.assertRaisesRegexp(ConfigValidationException,
'Verification of superuser someuser failed'): 'Verification of superuser someuser failed'):

View file

@ -79,7 +79,7 @@ class KubernetesConfigProvider(FileConfigProvider):
# as an error, as it seems to be a common issue. # as an error, as it seems to be a common issue.
namespace_url = 'namespaces/%s' % (QE_NAMESPACE) namespace_url = 'namespaces/%s' % (QE_NAMESPACE)
response = self._execute_k8s_api('GET', namespace_url) response = self._execute_k8s_api('GET', namespace_url)
if response.status_code / 100 != 2: if response.status_code // 100 != 2:
msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE msg = 'A Kubernetes namespace with name `%s` must be created to save config' % QE_NAMESPACE
raise CannotWriteConfigException(msg) raise CannotWriteConfigException(msg)

View file

@ -3,10 +3,8 @@ import subprocess
import time import time
from StringIO import StringIO from StringIO import StringIO
from fnmatch import fnmatch
from hashlib import sha1 from hashlib import sha1
import OpenSSL
import ldap import ldap
import peewee import peewee
import redis import redis
@ -28,6 +26,7 @@ from util.config.oauth import GoogleOAuthConfig, GithubOAuthConfig, GitLabOAuthC
from util.secscan.api import SecurityScannerAPI from util.secscan.api import SecurityScannerAPI
from util.registry.torrent import torrent_jwt from util.registry.torrent import torrent_jwt
from util.security.signing import SIGNING_ENGINES from util.security.signing import SIGNING_ENGINES
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -257,22 +256,32 @@ def _validate_ssl(config, user_obj, _):
if config.get('EXTERNAL_TLS_TERMINATION', False) is True: if config.get('EXTERNAL_TLS_TERMINATION', False) is True:
return return
# Verify that we have all the required SSL files.
for filename in SSL_FILENAMES: for filename in SSL_FILENAMES:
if not config_provider.volume_file_exists(filename): if not config_provider.volume_file_exists(filename):
raise ConfigValidationException('Missing required SSL file: %s' % filename) raise ConfigValidationException('Missing required SSL file: %s' % filename)
# Read the contents of the SSL certificate.
with config_provider.get_volume_file(SSL_FILENAMES[0]) as f: with config_provider.get_volume_file(SSL_FILENAMES[0]) as f:
cert_contents = f.read() cert_contents = f.read()
# Validate the certificate. # Validate the certificate.
try: try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents) certificate = load_certificate(cert_contents)
except: except CertInvalidException as cie:
raise ConfigValidationException('Could not parse certificate file. Is it a valid PEM certificate?') raise ConfigValidationException('Could not load SSL certificate: %s' % cie.message)
if cert.has_expired(): # Verify the certificate has not expired.
if certificate.expired:
raise ConfigValidationException('The specified SSL certificate has expired.') raise ConfigValidationException('The specified SSL certificate has expired.')
# Verify the hostname matches the name in the certificate.
if not certificate.matches_name(config['SERVER_HOSTNAME']):
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(certificate.names)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
# Verify the private key against the certificate.
private_key_path = None private_key_path = None
with config_provider.get_volume_file(SSL_FILENAMES[1]) as f: with config_provider.get_volume_file(SSL_FILENAMES[1]) as f:
private_key_path = f.name private_key_path = f.name
@ -281,44 +290,10 @@ def _validate_ssl(config, user_obj, _):
# Only in testing. # Only in testing.
return return
# Validate the private key with the certificate.
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(cert)
try: try:
context.use_privatekey_file(private_key_path) certificate.validate_private_key(private_key_path)
except: except KeyInvalidException as kie:
raise ConfigValidationException('Could not parse key file. Is it a valid PEM private key?') raise ConfigValidationException('SSL private key failed to validate: %s' % kie.message)
try:
context.check_privatekey()
except OpenSSL.SSL.Error as e:
raise ConfigValidationException('SSL key failed to validate: %s' % str(e))
# Verify the hostname matches the name in the certificate.
common_name = cert.get_subject().commonName
if common_name is None:
raise ConfigValidationException('Missing CommonName (CN) from SSL certificate')
# Build the list of allowed host patterns.
hosts = set([common_name])
# Find the DNS extension, if any.
for i in range(0, cert.get_extension_count()):
ext = cert.get_extension(i)
if ext.get_short_name() == 'subjectAltName':
value = str(ext)
hosts.update([host.strip()[4:] for host in value.split(',')])
# Check each host.
for host in hosts:
if fnmatch(config['SERVER_HOSTNAME'], host):
return
msg = ('Supported names "%s" in SSL cert do not match server hostname "%s"' %
(', '.join(list(hosts)), config['SERVER_HOSTNAME']))
raise ConfigValidationException(msg)
def _validate_ldap(config, user_obj, password): def _validate_ldap(config, user_obj, password):

81
util/security/ssl.py Normal file
View file

@ -0,0 +1,81 @@
from fnmatch import fnmatch
import OpenSSL
class CertInvalidException(Exception):
""" Exception raised when a certificate could not be parsed/loaded. """
pass
class KeyInvalidException(Exception):
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
pass
def load_certificate(cert_contents):
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
on failure.
"""
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
return SSLCertificate(cert)
except OpenSSL.crypto.Error as ex:
raise CertInvalidException(ex.message[0][2])
_SUBJECT_ALT_NAME = 'subjectAltName'
class SSLCertificate(object):
""" Helper class for easier working with SSL certificates. """
def __init__(self, openssl_cert):
self.openssl_cert = openssl_cert
def validate_private_key(self, private_key_path):
""" Validates that the private key found at the given file path applies to this certificate.
Raises a KeyInvalidException on failure.
"""
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(self.openssl_cert)
try:
context.use_privatekey_file(private_key_path)
context.check_privatekey()
except OpenSSL.SSL.Error as ex:
raise KeyInvalidException(ex.message[0][2])
def matches_name(self, check_name):
""" Returns true if this SSL certificate matches the given DNS hostname. """
for dns_name in self.names:
if fnmatch(dns_name, check_name):
return True
return False
@property
def expired(self):
""" Returns whether the SSL certificate has expired. """
return self.openssl_cert.has_expired()
@property
def common_name(self):
""" Returns the defined common name for the certificate, if any. """
return self.openssl_cert.get_subject().commonName
@property
def names(self):
""" Returns all the DNS named to which the certificate applies. May be empty. """
dns_names = set()
common_name = self.common_name
if common_name is not None:
dns_names.add(common_name)
# Find the DNS extension, if any.
for i in range(0, self.openssl_cert.get_extension_count()):
ext = self.openssl_cert.get_extension(i)
if ext.get_short_name() == _SUBJECT_ALT_NAME:
value = str(ext)
for san_name in value.split(','):
san_name_trimmed = san_name.strip()
if san_name_trimmed.startswith('DNS:'):
dns_names.add(san_name_trimmed[4:])
return dns_names