import unittest

from tempfile import NamedTemporaryFile

from OpenSSL import crypto

from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException


def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000):
    """Generate a self-signed test certificate and its private key.

    Args:
        hostname: value placed in the certificate subject's common name (CN).
        san_list: optional list of subjectAltName entries such as 'DNS:foo' or
            'URI:bar'. When given, they are added as one subjectAltName extension.
        expires: number of seconds from now for the notAfter field. A negative
            value produces a certificate that is already expired.

    Returns:
        A (cert_data, key_data) tuple, both PEM-encoded.
    """
    # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/

    # Create a key pair. 2048 bits (and SHA-256 below) because 1024-bit keys and
    # SHA-1 signatures are rejected by OpenSSL at security level 2, which would make
    # key/cert validation fail for reasons unrelated to what is being tested.
    key = crypto.PKey()
    key.generate_key(crypto.TYPE_RSA, 2048)

    # Create a self-signed cert.
    cert = crypto.X509()
    # Extensions such as subjectAltName are only valid in X.509 v3 certificates;
    # the version field is zero-based, so 2 means v3.
    cert.set_version(2)
    cert.get_subject().CN = hostname

    # Add the subjectAltNames (if necessary). pyOpenSSL expects bytes for both the
    # extension type name and its value.
    if san_list is not None:
        san_value = ', '.join(san_list).encode('ascii')
        cert.add_extensions([crypto.X509Extension(b'subjectAltName', False, san_value)])

    cert.set_serial_number(1000)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(expires)
    cert.set_issuer(cert.get_subject())

    cert.set_pubkey(key)
    cert.sign(key, 'sha256')

    # Dump the certificate and private key in PEM format.
    cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
    key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, key)

    return (cert_data, key_data)


class TestSSLCertificate(unittest.TestCase):
    """Tests for load_certificate and the certificate object it returns."""

    def test_load_certificate(self):
        # Try loading an invalid certificate.
        with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
            load_certificate('someinvalidcontents')

        # Load a valid certificate.
        (public_key_data, _) = generate_test_cert()

        cert = load_certificate(public_key_data)
        self.assertFalse(cert.expired)
        self.assertEqual(set(['somehostname']), cert.names)
        self.assertTrue(cert.matches_name('somehostname'))

    def test_expired_certificate(self):
        # notAfter 100 seconds in the past.
        (public_key_data, _) = generate_test_cert(expires=-100)

        cert = load_certificate(public_key_data)
        self.assertTrue(cert.expired)

    def test_hostnames(self):
        (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
        cert = load_certificate(public_key_data)
        # Names come from both the CN and the DNS subjectAltName entries.
        self.assertEqual(set(['foo', 'bar', 'baz']), cert.names)

        for name in cert.names:
            self.assertTrue(cert.matches_name(name))

    def test_wildcard_hostnames(self):
        (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar'])
        cert = load_certificate(public_key_data)
        self.assertEqual(set(['foo', '*.bar']), cert.names)

        for name in cert.names:
            self.assertTrue(cert.matches_name(name))

        self.assertTrue(cert.matches_name('something.bar'))
        self.assertTrue(cert.matches_name('somethingelse.bar'))
        self.assertTrue(cert.matches_name('cool.bar'))
        self.assertFalse(cert.matches_name('*'))

    def test_nondns_hostnames(self):
        # Non-DNS subjectAltName entries (here a URI) must not appear in names.
        (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg'])
        cert = load_certificate(public_key_data)
        self.assertEqual(set(['foo']), cert.names)

    def test_validate_private_key(self):
        (public_key_data, private_key_data) = generate_test_cert()
        cert = load_certificate(public_key_data)

        # The context manager closes (and deletes) the temp file when done.
        with NamedTemporaryFile(delete=True) as private_key:
            private_key.write(private_key_data)
            # validate_private_key only receives the path, so the buffered data
            # must be flushed to disk first.
            private_key.flush()

            cert.validate_private_key(private_key.name)

    def test_invalid_private_key(self):
        (public_key_data, _) = generate_test_cert()
        cert = load_certificate(public_key_data)

        with NamedTemporaryFile(delete=True) as private_key:
            private_key.write(b'somerandomdata')
            private_key.flush()

            with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
                cert.validate_private_key(private_key.name)

    def test_mismatch_private_key(self):
        # Certificate from one key pair, private key from a different one.
        (public_key_data, _) = generate_test_cert()
        (_, private_key_data) = generate_test_cert()
        cert = load_certificate(public_key_data)

        with NamedTemporaryFile(delete=True) as private_key:
            private_key.write(private_key_data)
            private_key.flush()

            with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
                cert.validate_private_key(private_key.name)


if __name__ == '__main__':
    unittest.main()