85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
from fnmatch import fnmatch
|
|
|
|
import OpenSSL
|
|
|
|
|
|
class CertInvalidException(Exception):
|
|
""" Exception raised when a certificate could not be parsed/loaded. """
|
|
pass
|
|
|
|
|
|
class KeyInvalidException(Exception):
|
|
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
|
|
pass
|
|
|
|
|
|
def load_certificate(cert_contents):
|
|
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
|
|
on failure.
|
|
"""
|
|
try:
|
|
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
|
|
return SSLCertificate(cert)
|
|
except OpenSSL.crypto.Error as ex:
|
|
raise CertInvalidException(ex.message[0][2])
|
|
|
|
|
|
_SUBJECT_ALT_NAME = 'subjectAltName'
|
|
|
|
|
|
class SSLCertificate(object):
|
|
""" Helper class for easier working with SSL certificates. """
|
|
|
|
def __init__(self, openssl_cert):
|
|
self.openssl_cert = openssl_cert
|
|
|
|
def validate_private_key(self, private_key_path):
|
|
""" Validates that the private key found at the given file path applies to this certificate.
|
|
Raises a KeyInvalidException on failure.
|
|
"""
|
|
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
|
|
context.use_certificate(self.openssl_cert)
|
|
|
|
try:
|
|
context.use_privatekey_file(private_key_path)
|
|
context.check_privatekey()
|
|
except OpenSSL.SSL.Error as ex:
|
|
raise KeyInvalidException(ex.message[0][2])
|
|
|
|
def matches_name(self, check_name):
|
|
""" Returns true if this SSL certificate matches the given DNS hostname. """
|
|
for dns_name in self.names:
|
|
if fnmatch(check_name, dns_name):
|
|
return True
|
|
|
|
return False
|
|
|
|
@property
|
|
def expired(self):
|
|
""" Returns whether the SSL certificate has expired. """
|
|
return self.openssl_cert.has_expired()
|
|
|
|
@property
|
|
def common_name(self):
|
|
""" Returns the defined common name for the certificate, if any. """
|
|
return self.openssl_cert.get_subject().commonName
|
|
|
|
@property
|
|
def names(self):
|
|
""" Returns all the DNS named to which the certificate applies. May be empty. """
|
|
dns_names = set()
|
|
common_name = self.common_name
|
|
if common_name is not None:
|
|
dns_names.add(common_name)
|
|
|
|
# Find the DNS extension, if any.
|
|
for i in range(0, self.openssl_cert.get_extension_count()):
|
|
ext = self.openssl_cert.get_extension(i)
|
|
if ext.get_short_name() == _SUBJECT_ALT_NAME:
|
|
value = str(ext)
|
|
for san_name in value.split(','):
|
|
san_name_trimmed = san_name.strip()
|
|
if san_name_trimmed.startswith('DNS:'):
|
|
dns_names.add(san_name_trimmed[4:])
|
|
|
|
return dns_names
|