Move SSL util tests to pytest
This commit is contained in:
		
							parent
							
								
									a830d53212
								
							
						
					
					
						commit
						894b754121
					
				
					 2 changed files with 116 additions and 121 deletions
				
			
		
							
								
								
									
										116
									
								
								util/security/test/test_ssl_util.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										116
									
								
								util/security/test/test_ssl_util.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,116 @@ | |||
| from tempfile import NamedTemporaryFile | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| from OpenSSL import crypto | ||||
| 
 | ||||
| from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException | ||||
| 
 | ||||
| def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000): | ||||
|   """ Generates a test SSL certificate and returns the certificate data and private key data. """ | ||||
| 
 | ||||
|   # Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/ | ||||
|   # Create a key pair. | ||||
|   k = crypto.PKey() | ||||
|   k.generate_key(crypto.TYPE_RSA, 1024) | ||||
| 
 | ||||
|   # Create a self-signed cert. | ||||
|   cert = crypto.X509() | ||||
|   cert.get_subject().CN = hostname | ||||
| 
 | ||||
|   # Add the subjectAltNames (if necessary). | ||||
|   if san_list is not None: | ||||
|     cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))]) | ||||
| 
 | ||||
|   cert.set_serial_number(1000) | ||||
|   cert.gmtime_adj_notBefore(0) | ||||
|   cert.gmtime_adj_notAfter(expires) | ||||
|   cert.set_issuer(cert.get_subject()) | ||||
| 
 | ||||
|   cert.set_pubkey(k) | ||||
|   cert.sign(k, 'sha1') | ||||
| 
 | ||||
|   # Dump the certificate and private key in PEM format. | ||||
|   cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert) | ||||
|   key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k) | ||||
| 
 | ||||
|   return (cert_data, key_data) | ||||
| 
 | ||||
| 
 | ||||
| def test_load_certificate(): | ||||
|   # Try loading an invalid certificate. | ||||
|   with pytest.raises(CertInvalidException): | ||||
|     load_certificate('someinvalidcontents') | ||||
| 
 | ||||
|   # Load a valid certificate. | ||||
|   (public_key_data, _) = generate_test_cert() | ||||
| 
 | ||||
|   cert = load_certificate(public_key_data) | ||||
|   assert not cert.expired | ||||
|   assert cert.names == set(['somehostname']) | ||||
|   assert cert.matches_name('somehostname') | ||||
| 
 | ||||
| def test_expired_certificate(): | ||||
|   (public_key_data, _) = generate_test_cert(expires=-100) | ||||
| 
 | ||||
|   cert = load_certificate(public_key_data) | ||||
|   assert cert.expired | ||||
| 
 | ||||
| def test_hostnames(): | ||||
|   (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz']) | ||||
|   cert = load_certificate(public_key_data) | ||||
|   assert cert.names == set(['foo', 'bar', 'baz']) | ||||
| 
 | ||||
|   for name in cert.names: | ||||
|     assert cert.matches_name(name) | ||||
| 
 | ||||
| def test_wildcard_hostnames(): | ||||
|   (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar']) | ||||
|   cert = load_certificate(public_key_data) | ||||
|   assert cert.names == set(['foo', '*.bar']) | ||||
| 
 | ||||
|   for name in cert.names: | ||||
|     assert cert.matches_name(name) | ||||
| 
 | ||||
|   assert cert.matches_name('something.bar') | ||||
|   assert cert.matches_name('somethingelse.bar') | ||||
|   assert cert.matches_name('cool.bar') | ||||
|   assert not cert.matches_name('*') | ||||
| 
 | ||||
| def test_nondns_hostnames(): | ||||
|   (public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg']) | ||||
|   cert = load_certificate(public_key_data) | ||||
|   assert cert.names == set(['foo']) | ||||
| 
 | ||||
| def test_validate_private_key(): | ||||
|   (public_key_data, private_key_data) = generate_test_cert() | ||||
| 
 | ||||
|   private_key = NamedTemporaryFile(delete=True) | ||||
|   private_key.write(private_key_data) | ||||
|   private_key.seek(0) | ||||
| 
 | ||||
|   cert = load_certificate(public_key_data) | ||||
|   cert.validate_private_key(private_key.name) | ||||
| 
 | ||||
| def test_invalid_private_key(): | ||||
|   (public_key_data, _) = generate_test_cert() | ||||
| 
 | ||||
|   private_key = NamedTemporaryFile(delete=True) | ||||
|   private_key.write('somerandomdata') | ||||
|   private_key.seek(0) | ||||
| 
 | ||||
|   cert = load_certificate(public_key_data) | ||||
|   with pytest.raises(KeyInvalidException): | ||||
|     cert.validate_private_key(private_key.name) | ||||
| 
 | ||||
| def test_mismatch_private_key(): | ||||
|   (public_key_data, _) = generate_test_cert() | ||||
|   (_, private_key_data) = generate_test_cert() | ||||
| 
 | ||||
|   private_key = NamedTemporaryFile(delete=True) | ||||
|   private_key.write(private_key_data) | ||||
|   private_key.seek(0) | ||||
| 
 | ||||
|   cert = load_certificate(public_key_data) | ||||
|   with pytest.raises(KeyInvalidException): | ||||
|     cert.validate_private_key(private_key.name) | ||||
		Reference in a new issue