Move SSL util tests to pytest
This commit is contained in:
parent
a830d53212
commit
894b754121
2 changed files with 116 additions and 121 deletions
|
@ -1,121 +0,0 @@
|
||||||
import unittest
|
|
||||||
|
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
from OpenSSL import crypto
|
|
||||||
|
|
||||||
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
|
||||||
|
|
||||||
def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000):
|
|
||||||
""" Generates a test SSL certificate and returns the certificate data and private key data. """
|
|
||||||
|
|
||||||
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
|
|
||||||
# Create a key pair.
|
|
||||||
k = crypto.PKey()
|
|
||||||
k.generate_key(crypto.TYPE_RSA, 1024)
|
|
||||||
|
|
||||||
# Create a self-signed cert.
|
|
||||||
cert = crypto.X509()
|
|
||||||
cert.get_subject().CN = hostname
|
|
||||||
|
|
||||||
# Add the subjectAltNames (if necessary).
|
|
||||||
if san_list is not None:
|
|
||||||
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
|
|
||||||
|
|
||||||
cert.set_serial_number(1000)
|
|
||||||
cert.gmtime_adj_notBefore(0)
|
|
||||||
cert.gmtime_adj_notAfter(expires)
|
|
||||||
cert.set_issuer(cert.get_subject())
|
|
||||||
|
|
||||||
cert.set_pubkey(k)
|
|
||||||
cert.sign(k, 'sha1')
|
|
||||||
|
|
||||||
# Dump the certificate and private key in PEM format.
|
|
||||||
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
|
||||||
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
|
|
||||||
|
|
||||||
return (cert_data, key_data)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSSLCertificate(unittest.TestCase):
|
|
||||||
def test_load_certificate(self):
|
|
||||||
# Try loading an invalid certificate.
|
|
||||||
with self.assertRaisesRegexp(CertInvalidException, 'no start line'):
|
|
||||||
load_certificate('someinvalidcontents')
|
|
||||||
|
|
||||||
# Load a valid certificate.
|
|
||||||
(public_key_data, _) = generate_test_cert()
|
|
||||||
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
self.assertFalse(cert.expired)
|
|
||||||
self.assertEquals(set(['somehostname']), cert.names)
|
|
||||||
self.assertTrue(cert.matches_name('somehostname'))
|
|
||||||
|
|
||||||
def test_expired_certificate(self):
|
|
||||||
(public_key_data, _) = generate_test_cert(expires=-100)
|
|
||||||
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
self.assertTrue(cert.expired)
|
|
||||||
|
|
||||||
def test_hostnames(self):
|
|
||||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
self.assertEquals(set(['foo', 'bar', 'baz']), cert.names)
|
|
||||||
|
|
||||||
for name in cert.names:
|
|
||||||
self.assertTrue(cert.matches_name(name))
|
|
||||||
|
|
||||||
def test_wildcard_hostnames(self):
|
|
||||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar'])
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
self.assertEquals(set(['foo', '*.bar']), cert.names)
|
|
||||||
|
|
||||||
for name in cert.names:
|
|
||||||
self.assertTrue(cert.matches_name(name))
|
|
||||||
|
|
||||||
self.assertTrue(cert.matches_name('something.bar'))
|
|
||||||
self.assertTrue(cert.matches_name('somethingelse.bar'))
|
|
||||||
self.assertTrue(cert.matches_name('cool.bar'))
|
|
||||||
self.assertFalse(cert.matches_name('*'))
|
|
||||||
|
|
||||||
def test_nondns_hostnames(self):
|
|
||||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg'])
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
self.assertEquals(set(['foo']), cert.names)
|
|
||||||
|
|
||||||
def test_validate_private_key(self):
|
|
||||||
(public_key_data, private_key_data) = generate_test_cert()
|
|
||||||
|
|
||||||
private_key = NamedTemporaryFile(delete=True)
|
|
||||||
private_key.write(private_key_data)
|
|
||||||
private_key.seek(0)
|
|
||||||
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
cert.validate_private_key(private_key.name)
|
|
||||||
|
|
||||||
def test_invalid_private_key(self):
|
|
||||||
(public_key_data, _) = generate_test_cert()
|
|
||||||
|
|
||||||
private_key = NamedTemporaryFile(delete=True)
|
|
||||||
private_key.write('somerandomdata')
|
|
||||||
private_key.seek(0)
|
|
||||||
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
with self.assertRaisesRegexp(KeyInvalidException, 'no start line'):
|
|
||||||
cert.validate_private_key(private_key.name)
|
|
||||||
|
|
||||||
def test_mismatch_private_key(self):
|
|
||||||
(public_key_data, _) = generate_test_cert()
|
|
||||||
(_, private_key_data) = generate_test_cert()
|
|
||||||
|
|
||||||
private_key = NamedTemporaryFile(delete=True)
|
|
||||||
private_key.write(private_key_data)
|
|
||||||
private_key.seek(0)
|
|
||||||
|
|
||||||
cert = load_certificate(public_key_data)
|
|
||||||
with self.assertRaisesRegexp(KeyInvalidException, 'key values mismatch'):
|
|
||||||
cert.validate_private_key(private_key.name)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
unittest.main()
|
|
116
util/security/test/test_ssl_util.py
Normal file
116
util/security/test/test_ssl_util.py
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from OpenSSL import crypto
|
||||||
|
|
||||||
|
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
||||||
|
|
||||||
|
def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000):
|
||||||
|
""" Generates a test SSL certificate and returns the certificate data and private key data. """
|
||||||
|
|
||||||
|
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
|
||||||
|
# Create a key pair.
|
||||||
|
k = crypto.PKey()
|
||||||
|
k.generate_key(crypto.TYPE_RSA, 1024)
|
||||||
|
|
||||||
|
# Create a self-signed cert.
|
||||||
|
cert = crypto.X509()
|
||||||
|
cert.get_subject().CN = hostname
|
||||||
|
|
||||||
|
# Add the subjectAltNames (if necessary).
|
||||||
|
if san_list is not None:
|
||||||
|
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
|
||||||
|
|
||||||
|
cert.set_serial_number(1000)
|
||||||
|
cert.gmtime_adj_notBefore(0)
|
||||||
|
cert.gmtime_adj_notAfter(expires)
|
||||||
|
cert.set_issuer(cert.get_subject())
|
||||||
|
|
||||||
|
cert.set_pubkey(k)
|
||||||
|
cert.sign(k, 'sha1')
|
||||||
|
|
||||||
|
# Dump the certificate and private key in PEM format.
|
||||||
|
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||||
|
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
|
||||||
|
|
||||||
|
return (cert_data, key_data)
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_certificate():
|
||||||
|
# Try loading an invalid certificate.
|
||||||
|
with pytest.raises(CertInvalidException):
|
||||||
|
load_certificate('someinvalidcontents')
|
||||||
|
|
||||||
|
# Load a valid certificate.
|
||||||
|
(public_key_data, _) = generate_test_cert()
|
||||||
|
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
assert not cert.expired
|
||||||
|
assert cert.names == set(['somehostname'])
|
||||||
|
assert cert.matches_name('somehostname')
|
||||||
|
|
||||||
|
def test_expired_certificate():
|
||||||
|
(public_key_data, _) = generate_test_cert(expires=-100)
|
||||||
|
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
assert cert.expired
|
||||||
|
|
||||||
|
def test_hostnames():
|
||||||
|
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
assert cert.names == set(['foo', 'bar', 'baz'])
|
||||||
|
|
||||||
|
for name in cert.names:
|
||||||
|
assert cert.matches_name(name)
|
||||||
|
|
||||||
|
def test_wildcard_hostnames():
|
||||||
|
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar'])
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
assert cert.names == set(['foo', '*.bar'])
|
||||||
|
|
||||||
|
for name in cert.names:
|
||||||
|
assert cert.matches_name(name)
|
||||||
|
|
||||||
|
assert cert.matches_name('something.bar')
|
||||||
|
assert cert.matches_name('somethingelse.bar')
|
||||||
|
assert cert.matches_name('cool.bar')
|
||||||
|
assert not cert.matches_name('*')
|
||||||
|
|
||||||
|
def test_nondns_hostnames():
|
||||||
|
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg'])
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
assert cert.names == set(['foo'])
|
||||||
|
|
||||||
|
def test_validate_private_key():
|
||||||
|
(public_key_data, private_key_data) = generate_test_cert()
|
||||||
|
|
||||||
|
private_key = NamedTemporaryFile(delete=True)
|
||||||
|
private_key.write(private_key_data)
|
||||||
|
private_key.seek(0)
|
||||||
|
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
cert.validate_private_key(private_key.name)
|
||||||
|
|
||||||
|
def test_invalid_private_key():
|
||||||
|
(public_key_data, _) = generate_test_cert()
|
||||||
|
|
||||||
|
private_key = NamedTemporaryFile(delete=True)
|
||||||
|
private_key.write('somerandomdata')
|
||||||
|
private_key.seek(0)
|
||||||
|
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
with pytest.raises(KeyInvalidException):
|
||||||
|
cert.validate_private_key(private_key.name)
|
||||||
|
|
||||||
|
def test_mismatch_private_key():
|
||||||
|
(public_key_data, _) = generate_test_cert()
|
||||||
|
(_, private_key_data) = generate_test_cert()
|
||||||
|
|
||||||
|
private_key = NamedTemporaryFile(delete=True)
|
||||||
|
private_key.write(private_key_data)
|
||||||
|
private_key.seek(0)
|
||||||
|
|
||||||
|
cert = load_certificate(public_key_data)
|
||||||
|
with pytest.raises(KeyInvalidException):
|
||||||
|
cert.validate_private_key(private_key.name)
|
Reference in a new issue