from fnmatch import fnmatch

import OpenSSL


class CertInvalidException(Exception):
    """
    Exception raised when a certificate could not be parsed/loaded.
    """

    pass


class KeyInvalidException(Exception):
    """
    Exception raised when a key could not be parsed/loaded or successfully applied to a cert.
    """

    pass


def _openssl_error_reason(ex):
    """
    Returns the human-readable reason carried by a pyOpenSSL error.

    The reason is read from `ex.args[0][0][2]`, i.e. the third field of the first entry of the
    error's first argument (assumed to be a list of (library, function, reason) tuples; confirm
    against the pyOpenSSL version in use). If the error does not have that shape, the string form
    of the exception is returned instead, so that callers never fail while reporting a failure.
    """
    try:
        # `args` (rather than the Python 2-only `message` attribute) works on Python 2 and 3.
        return ex.args[0][0][2]
    except (IndexError, TypeError):
        return str(ex)


def load_certificate(cert_contents):
    """
    Loads the PEM-encoded certificate from the given contents and returns it wrapped in an
    SSLCertificate.

    Raises a CertInvalidException (carrying the OpenSSL failure reason) if OpenSSL rejects the
    contents.
    """
    try:
        cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
    except OpenSSL.crypto.Error as ex:
        raise CertInvalidException(_openssl_error_reason(ex))

    return SSLCertificate(cert)


_SUBJECT_ALT_NAME = 'subjectAltName'


class SSLCertificate(object):
    """
    Helper class for easier working with SSL certificates.
    """

    def __init__(self, openssl_cert):
        # The underlying pyOpenSSL certificate object that all accessors delegate to.
        self.openssl_cert = openssl_cert

    def validate_private_key(self, private_key_path):
        """
        Validates that the private key found at the given file path applies to this certificate.

        Raises a KeyInvalidException (carrying the OpenSSL failure reason) on failure.
        """
        context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
        context.use_certificate(self.openssl_cert)

        try:
            context.use_privatekey_file(private_key_path)
            context.check_privatekey()
        except OpenSSL.SSL.Error as ex:
            raise KeyInvalidException(_openssl_error_reason(ex))

    def matches_name(self, check_name):
        """
        Returns true if this SSL certificate matches the given DNS hostname.

        Each name on the certificate is treated as a shell-style pattern (so a wildcard entry
        such as `*.example.com` is honored) and the given hostname is matched against it.
        """
        # fnmatch takes (name, pattern): the hostname is the name, the cert entry is the pattern.
        return any(fnmatch(check_name, dns_name) for dns_name in self.names)

    @property
    def expired(self):
        """
        Returns whether the SSL certificate has expired.
        """
        return self.openssl_cert.has_expired()

    @property
    def common_name(self):
        """
        Returns the defined common name for the certificate, if any.
        """
        return self.openssl_cert.get_subject().commonName

    @property
    def names(self):
        """
        Returns all the DNS names to which the certificate applies, as a set.

        The set contains the common name (when defined) plus every `DNS:` entry found in the
        certificate's subjectAltName extension(s). May be empty.
        """
        dns_names = set()
        common_name = self.common_name
        if common_name is not None:
            dns_names.add(common_name)

        # Find the DNS extension, if any.
        for index in range(self.openssl_cert.get_extension_count()):
            ext = self.openssl_cert.get_extension(index)

            # The short name may be returned as bytes; normalize it to text so the comparison
            # with the text constant does not silently fail.
            short_name = ext.get_short_name()
            if isinstance(short_name, bytes):
                short_name = short_name.decode('ascii')

            if short_name != _SUBJECT_ALT_NAME:
                continue

            # The string form of the extension is a comma-separated list of `TYPE:value` items;
            # only the DNS entries are of interest here.
            for san_name in str(ext).split(','):
                san_name_trimmed = san_name.strip()
                if san_name_trimmed.startswith('DNS:'):
                    dns_names.add(san_name_trimmed[4:])

        return dns_names