diff --git a/test/test_ssl_util.py b/test/test_ssl_util.py deleted file mode 100644 index 0bfd7be05..000000000 --- a/test/test_ssl_util.py +++ /dev/null @@ -1,121 +0,0 @@ -import unittest - -from tempfile import NamedTemporaryFile -from OpenSSL import crypto - -from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException - -def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000): - """ Generates a test SSL certificate and returns the certificate data and private key data. """ - - # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/ - # Create a key pair. - k = crypto.PKey() - k.generate_key(crypto.TYPE_RSA, 1024) - - # Create a self-signed cert. - cert = crypto.X509() - cert.get_subject().CN = hostname - - # Add the subjectAltNames (if necessary). - if san_list is not None: - cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))]) - - cert.set_serial_number(1000) - cert.gmtime_adj_notBefore(0) - cert.gmtime_adj_notAfter(expires) - cert.set_issuer(cert.get_subject()) - - cert.set_pubkey(k) - cert.sign(k, 'sha1') - - # Dump the certificate and private key in PEM format. - cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) - key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) - - return (cert_data, key_data) - - -class TestSSLCertificate(unittest.TestCase): - def test_load_certificate(self): - # Try loading an invalid certificate. - with self.assertRaisesRegexp(CertInvalidException, 'no start line'): - load_certificate('someinvalidcontents') - - # Load a valid certificate. - (public_key_data, _) = generate_test_cert() - - cert = load_certificate(public_key_data) - self.assertFalse(cert.expired) - self.assertEquals(set(['somehostname']), cert.names) - self.assertTrue(cert.matches_name('somehostname')) - - def test_expired_certificate(self): - (public_key_data, _) = generate_test_cert(expires=-100) - - cert = load_certificate(public_key_data) - self.assertTrue(cert.expired) - - def test_hostnames(self): - (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz']) - cert = load_certificate(public_key_data) - self.assertEquals(set(['foo', 'bar', 'baz']), cert.names) - - for name in cert.names: - self.assertTrue(cert.matches_name(name)) - - def test_wildcard_hostnames(self): - (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar']) - cert = load_certificate(public_key_data) - self.assertEquals(set(['foo', '*.bar']), cert.names) - - for name in cert.names: - self.assertTrue(cert.matches_name(name)) - - self.assertTrue(cert.matches_name('something.bar')) - self.assertTrue(cert.matches_name('somethingelse.bar')) - self.assertTrue(cert.matches_name('cool.bar')) - self.assertFalse(cert.matches_name('*')) - - def test_nondns_hostnames(self): - (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg']) - cert = load_certificate(public_key_data) - self.assertEquals(set(['foo']), cert.names) - - def test_validate_private_key(self): - (public_key_data, private_key_data) = generate_test_cert() - - private_key = NamedTemporaryFile(delete=True) - private_key.write(private_key_data) - private_key.seek(0) - - cert = load_certificate(public_key_data) - cert.validate_private_key(private_key.name) - - def test_invalid_private_key(self): - (public_key_data, _) = generate_test_cert() - - private_key = NamedTemporaryFile(delete=True) - private_key.write('somerandomdata') - private_key.seek(0) - - cert = load_certificate(public_key_data) - with self.assertRaisesRegexp(KeyInvalidException, 'no start line'): - cert.validate_private_key(private_key.name) - - def test_mismatch_private_key(self): - (public_key_data, _) = generate_test_cert() - (_, private_key_data) = generate_test_cert() - - private_key = NamedTemporaryFile(delete=True) - private_key.write(private_key_data) - private_key.seek(0) - - cert = load_certificate(public_key_data) - with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'): - cert.validate_private_key(private_key.name) - - - -if __name__ == '__main__': - unittest.main() diff --git a/util/security/test/test_ssl_util.py b/util/security/test/test_ssl_util.py new file mode 100644 index 000000000..d38de483d --- /dev/null +++ b/util/security/test/test_ssl_util.py @@ -0,0 +1,116 @@ +from tempfile import NamedTemporaryFile + +import pytest + +from OpenSSL import crypto + +from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException + +def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000): + """ Generates a test SSL certificate and returns the certificate data and private key data. """ + + # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/ + # Create a key pair. + k = crypto.PKey() + k.generate_key(crypto.TYPE_RSA, 1024) + + # Create a self-signed cert. + cert = crypto.X509() + cert.get_subject().CN = hostname + + # Add the subjectAltNames (if necessary). + if san_list is not None: + cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))]) + + cert.set_serial_number(1000) + cert.gmtime_adj_notBefore(0) + cert.gmtime_adj_notAfter(expires) + cert.set_issuer(cert.get_subject()) + + cert.set_pubkey(k) + cert.sign(k, 'sha1') + + # Dump the certificate and private key in PEM format. + cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) + key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) + + return (cert_data, key_data) + + +def test_load_certificate(): + # Try loading an invalid certificate. + with pytest.raises(CertInvalidException): + load_certificate('someinvalidcontents') + + # Load a valid certificate. + (public_key_data, _) = generate_test_cert() + + cert = load_certificate(public_key_data) + assert not cert.expired + assert cert.names == set(['somehostname']) + assert cert.matches_name('somehostname') + +def test_expired_certificate(): + (public_key_data, _) = generate_test_cert(expires=-100) + + cert = load_certificate(public_key_data) + assert cert.expired + +def test_hostnames(): + (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz']) + cert = load_certificate(public_key_data) + assert cert.names == set(['foo', 'bar', 'baz']) + + for name in cert.names: + assert cert.matches_name(name) + +def test_wildcard_hostnames(): + (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar']) + cert = load_certificate(public_key_data) + assert cert.names == set(['foo', '*.bar']) + + for name in cert.names: + assert cert.matches_name(name) + + assert cert.matches_name('something.bar') + assert cert.matches_name('somethingelse.bar') + assert cert.matches_name('cool.bar') + assert not cert.matches_name('*') + +def test_nondns_hostnames(): + (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg']) + cert = load_certificate(public_key_data) + assert cert.names == set(['foo']) + +def test_validate_private_key(): + (public_key_data, private_key_data) = generate_test_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + cert.validate_private_key(private_key.name) + +def test_invalid_private_key(): + (public_key_data, _) = generate_test_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write('somerandomdata') + private_key.seek(0) + + cert = load_certificate(public_key_data) + with pytest.raises(KeyInvalidException): + cert.validate_private_key(private_key.name) + +def test_mismatch_private_key(): + (public_key_data, _) = generate_test_cert() + (_, private_key_data) = generate_test_cert() + + private_key = NamedTemporaryFile(delete=True) + private_key.write(private_key_data) + private_key.seek(0) + + cert = load_certificate(public_key_data) + with pytest.raises(KeyInvalidException): + cert.validate_private_key(private_key.name)