initial import for Open Source 🎉

This commit is contained in:
Jimmy Zelinskie 2019-11-12 11:09:47 -05:00
parent 1898c361f3
commit 9c0dd3b722
2048 changed files with 218743 additions and 0 deletions

View file

34
util/security/aes.py Normal file
View file

@ -0,0 +1,34 @@
from __future__ import absolute_import
import base64
import hashlib
from Crypto import Random
from Crypto.Cipher import AES
class AESCipher(object):
""" Helper class for encrypting and decrypting data via AES.
Copied From: http://stackoverflow.com/a/21928790
"""
def __init__(self, key):
self.bs = 32
self.key = key
def encrypt(self, raw):
raw = self._pad(raw)
iv = Random.new().read(AES.block_size)
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return base64.b64encode(iv + cipher.encrypt(raw))
def decrypt(self, enc):
enc = base64.b64decode(enc)
iv = enc[:AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size:])).decode('utf-8')
def _pad(self, s):
return s + (self.bs - len(s) % self.bs) * chr(self.bs - len(s) % self.bs)
@staticmethod
def _unpad(s):
return s[:-ord(s[len(s)-1:])]

18
util/security/crypto.py Normal file
View file

@ -0,0 +1,18 @@
import base64
from cryptography.fernet import Fernet, InvalidToken
def encrypt_string(string, key):
""" Encrypts a string with the specified key. The key must be 32 raw bytes. """
f = Fernet(key)
return f.encrypt(string)
def decrypt_string(string, key, ttl=None):
""" Decrypts an encrypted string with the specified key. The key must be 32 raw bytes. """
f = Fernet(key)
try:
return f.decrypt(str(string), ttl=ttl)
except InvalidToken:
return None
except TypeError:
return None

View file

@ -0,0 +1,16 @@
import json
from hashlib import sha256
from util.canonicaljson import canonicalize
def canonical_kid(jwk):
"""This function returns the SHA256 hash of a canonical JWK.
Args:
jwk (object): the JWK for which a kid will be generated.
Returns:
string: the unique kid for the given JWK.
"""
return sha256(json.dumps(canonicalize(jwk), separators=(',', ':'))).hexdigest()

View file

@ -0,0 +1,79 @@
from cachetools.func import lru_cache
from data import model
from util.expiresdict import ExpiresDict, ExpiresEntry
from util.security import jwtutil
class CachingKey(object):
def __init__(self, service_key):
self._service_key = service_key
self._cached_public_key = None
@property
def public_key(self):
cached_key = self._cached_public_key
if cached_key is not None:
return cached_key
# Convert the JWK into a public key and cache it (since the conversion can take > 200ms).
public_key = jwtutil.jwk_dict_to_public_key(self._service_key.jwk)
self._cached_public_key = public_key
return public_key
class InstanceKeys(object):
""" InstanceKeys defines a helper class for interacting with the Quay instance service keys
used for JWT signing of registry tokens as well as requests from Quay to other services
such as Clair. Each container will have a single registered instance key.
"""
def __init__(self, app):
self.app = app
self.instance_keys = ExpiresDict(self._load_instance_keys)
def clear_cache(self):
""" Clears the cache of instance keys. """
self.instance_keys = ExpiresDict(self._load_instance_keys)
def _load_instance_keys(self):
# Load all the instance keys.
keys = {}
for key in model.service_keys.list_service_keys(self.service_name):
keys[key.kid] = ExpiresEntry(CachingKey(key), key.expiration_date)
return keys
@property
def service_name(self):
""" Returns the name of the instance key's service (i.e. 'quay'). """
return self.app.config['INSTANCE_SERVICE_KEY_SERVICE']
@property
def service_key_expiration(self):
""" Returns the defined expiration for instance service keys, in minutes. """
return self.app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
@property
@lru_cache(maxsize=1)
def local_key_id(self):
""" Returns the ID of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'])
@property
@lru_cache(maxsize=1)
def local_private_key(self):
""" Returns the private key of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_LOCATION'])
def get_service_key_public_key(self, kid):
""" Returns the public key associated with the given instance service key or None if none. """
caching_key = self.instance_keys.get(kid)
if caching_key is None:
return None
return caching_key.public_key
def _load_file_contents(path):
""" Returns the contents of the specified file path. """
with open(path) as f:
return f.read()

116
util/security/jwtutil.py Normal file
View file

@ -0,0 +1,116 @@
import re
from calendar import timegm
from datetime import datetime, timedelta
from jwt import PyJWT
from jwt.exceptions import (
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
InvalidAlgorithmError
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwkest.jwk import keyrep, RSAKey, ECKey
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD
# `none` here!
ALGORITHM_WHITELIST = [
'rs256'
]
class _StrictJWT(PyJWT):
""" _StrictJWT defines a JWT decoder with extra checks. """
@staticmethod
def _get_default_options():
# Weird syntax to call super on a staticmethod
defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
defaults.update({
'require_exp': True,
'require_iat': True,
'require_nbf': True,
'exp_max_s': None,
})
return defaults
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
if options.get('exp_max_s') is not None:
if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
raise ValueError('exp_max_s option implies verify_expiration')
options['verify_exp'] = True
# Do all of the other checks
super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
now = timegm(datetime.utcnow().utctimetuple())
self._reject_future_iat(payload, now, leeway)
if 'exp' in payload and options.get('exp_max_s') is not None:
# Validate that the expiration was not more than exp_max_s seconds after the issue time
# or in the absence of an issue time, more than exp_max_s in the future from now
# This will work because the parent method already checked the type of exp
expiration = datetime.utcfromtimestamp(int(payload['exp']))
max_signed_s = options.get('exp_max_s')
start_time = datetime.utcnow()
if 'iat' in payload:
start_time = datetime.utcfromtimestamp(int(payload['iat']))
if expiration > start_time + timedelta(seconds=max_signed_s):
raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
start_time)
def _reject_future_iat(self, payload, now, leeway):
try:
iat = int(payload['iat'])
except ValueError:
raise DecodeError('Issued At claim (iat) must be an integer.')
if iat > (now + leeway):
raise InvalidIssuedAtError('Issued At claim (iat) cannot be in'
' the future.')
def decode(jwt, key='', verify=True, algorithms=None, options=None,
**kwargs):
""" Decodes a JWT. """
if not algorithms:
raise InvalidAlgorithmError('algorithms must be specified')
normalized = set([a.lower() for a in algorithms])
if 'none' in normalized:
raise InvalidAlgorithmError('`none` algorithm is not allowed')
if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized):
raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
(algorithms, ALGORITHM_WHITELIST))
return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs)
def exp_max_s_option(max_exp_s):
""" Returns an options dictionary that sets the maximum expiration seconds for a JWT. """
return {
'exp_max_s': max_exp_s,
}
def jwk_dict_to_public_key(jwk):
""" Converts the specified JWK into a public key. """
jwkest_key = keyrep(jwk)
if isinstance(jwkest_key, RSAKey):
pycrypto_key = jwkest_key.key
return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
elif isinstance(jwkest_key, ECKey):
x, y = jwkest_key.get_key()
return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())
raise Exception('Unsupported kind of JWK: %s', str(type(jwkest_key)))

View file

@ -0,0 +1,134 @@
import time
import jwt
import logging
from util.security import jwtutil
logger = logging.getLogger(__name__)
ANONYMOUS_SUB = '(anonymous)'
ALGORITHM = 'RS256'
CLAIM_TUF_ROOTS = 'com.apostille.roots'
CLAIM_TUF_ROOT = 'com.apostille.root'
QUAY_TUF_ROOT = 'quay'
SIGNER_TUF_ROOT = 'signer'
DISABLED_TUF_ROOT = '$disabled'
# The number of allowed seconds of clock skew for a JWT. The iat, nbf and exp are adjusted with this
# count.
JWT_CLOCK_SKEW_SECONDS = 30
class InvalidBearerTokenException(Exception):
pass
def decode_bearer_header(bearer_header, instance_keys, config, metric_queue=None):
""" decode_bearer_header decodes the given bearer header that contains an encoded JWT with both
a Key ID as well as the signed JWT and returns the decoded and validated JWT. On any error,
raises an InvalidBearerTokenException with the reason for failure.
"""
# Extract the jwt token from the header
match = jwtutil.TOKEN_REGEX.match(bearer_header)
if match is None:
raise InvalidBearerTokenException('Invalid bearer token format')
encoded_jwt = match.group(1)
logger.debug('encoded JWT: %s', encoded_jwt)
return decode_bearer_token(encoded_jwt, instance_keys, config, metric_queue=metric_queue)
def decode_bearer_token(bearer_token, instance_keys, config, metric_queue=None):
""" decode_bearer_token decodes the given bearer token that contains both a Key ID as well as the
encoded JWT and returns the decoded and validated JWT. On any error, raises an
InvalidBearerTokenException with the reason for failure.
"""
# Decode the key ID.
try:
headers = jwt.get_unverified_header(bearer_token)
except jwtutil.InvalidTokenError as ite:
logger.exception('Invalid token reason: %s', ite)
raise InvalidBearerTokenException(ite)
kid = headers.get('kid', None)
if kid is None:
logger.error('Missing kid header on encoded JWT: %s', bearer_token)
raise InvalidBearerTokenException('Missing kid header')
# Find the matching public key.
public_key = instance_keys.get_service_key_public_key(kid)
if public_key is None:
if metric_queue is not None:
metric_queue.invalid_instance_key_count.Inc(labelvalues=[kid])
logger.error('Could not find requested service key %s with encoded JWT: %s', kid, bearer_token)
raise InvalidBearerTokenException('Unknown service key')
# Load the JWT returned.
try:
expected_issuer = instance_keys.service_name
audience = config['SERVER_HOSTNAME']
max_signed_s = config.get('REGISTRY_JWT_AUTH_MAX_FRESH_S', 3660)
max_exp = jwtutil.exp_max_s_option(max_signed_s)
payload = jwtutil.decode(bearer_token, public_key, algorithms=[ALGORITHM], audience=audience,
issuer=expected_issuer, options=max_exp, leeway=JWT_CLOCK_SKEW_SECONDS)
except jwtutil.InvalidTokenError as ite:
logger.exception('Invalid token reason: %s', ite)
raise InvalidBearerTokenException(ite)
if not 'sub' in payload:
raise InvalidBearerTokenException('Missing sub field in JWT')
return payload
def generate_bearer_token(audience, subject, context, access, lifetime_s, instance_keys):
""" Generates a registry bearer token (without the 'Bearer ' portion) based on the given
information.
"""
return _generate_jwt_object(audience, subject, context, access, lifetime_s,
instance_keys.service_name, instance_keys.local_key_id,
instance_keys.local_private_key)
def _generate_jwt_object(audience, subject, context, access, lifetime_s, issuer, key_id,
private_key):
""" Generates a compact encoded JWT with the values specified. """
token_data = {
'iss': issuer,
'aud': audience,
'nbf': int(time.time()),
'iat': int(time.time()),
'exp': int(time.time() + lifetime_s),
'sub': subject,
'access': access,
'context': context,
}
token_headers = {
'kid': key_id,
}
return jwt.encode(token_data, private_key, ALGORITHM, headers=token_headers)
def build_context_and_subject(auth_context=None, tuf_roots=None):
""" Builds the custom context field for the JWT signed token and returns it,
along with the subject for the JWT signed token. """
# Serialize to a dictionary.
context = auth_context.to_signed_dict() if auth_context else {}
# TODO: remove once Apostille has been upgraded to not use the single root.
single_root = (tuf_roots.values()[0]
if tuf_roots is not None and len(tuf_roots) == 1
else DISABLED_TUF_ROOT)
context.update({
CLAIM_TUF_ROOTS: tuf_roots,
CLAIM_TUF_ROOT: single_root,
})
if not auth_context or auth_context.is_anonymous:
return (context, ANONYMOUS_SUB)
return (context, auth_context.authed_user.username if auth_context.authed_user else None)

28
util/security/secret.py Normal file
View file

@ -0,0 +1,28 @@
import itertools
import uuid
def convert_secret_key(config_secret_key):
""" Converts the secret key from the app config into a secret key that is usable by AES
Cipher. """
secret_key = None
# First try parsing the key as an int.
try:
big_int = int(config_secret_key)
secret_key = str(bytearray.fromhex('{:02x}'.format(big_int)))
except ValueError:
pass
# Next try parsing it as an UUID.
if secret_key is None:
try:
secret_key = uuid.UUID(config_secret_key).bytes
except ValueError:
pass
if secret_key is None:
secret_key = str(bytearray(map(ord, config_secret_key)))
# Otherwise, use the bytes directly.
assert len(secret_key)
return ''.join(itertools.islice(itertools.cycle(secret_key), 32))

82
util/security/signing.py Normal file
View file

@ -0,0 +1,82 @@
import gpgme
import os
import features
import logging
logger = logging.getLogger(__name__)
from StringIO import StringIO
class GPG2Signer(object):
""" Helper class for signing data using GPG2. """
def __init__(self, config, config_provider):
if not config.get('GPG2_PRIVATE_KEY_NAME'):
raise Exception('Missing configuration key GPG2_PRIVATE_KEY_NAME')
if not config.get('GPG2_PRIVATE_KEY_FILENAME'):
raise Exception('Missing configuration key GPG2_PRIVATE_KEY_FILENAME')
if not config.get('GPG2_PUBLIC_KEY_FILENAME'):
raise Exception('Missing configuration key GPG2_PUBLIC_KEY_FILENAME')
self._ctx = gpgme.Context()
self._ctx.armor = True
self._private_key_name = config['GPG2_PRIVATE_KEY_NAME']
self._public_key_filename = config['GPG2_PUBLIC_KEY_FILENAME']
self._config_provider = config_provider
if not config_provider.volume_file_exists(config['GPG2_PRIVATE_KEY_FILENAME']):
raise Exception('Missing key file %s' % config['GPG2_PRIVATE_KEY_FILENAME'])
with config_provider.get_volume_file(config['GPG2_PRIVATE_KEY_FILENAME'], mode='rb') as fp:
self._ctx.import_(fp)
@property
def name(self):
return 'gpg2'
def open_public_key_file(self):
return self._config_provider.get_volume_file(self._public_key_filename, mode='rb')
def detached_sign(self, stream):
""" Signs the given stream, returning the signature. """
ctx = self._ctx
try:
ctx.signers = [ctx.get_key(self._private_key_name)]
except:
raise Exception('Invalid private key name')
signature = StringIO()
new_sigs = ctx.sign(stream, signature, gpgme.SIG_MODE_DETACH)
signature.seek(0)
return signature.getvalue()
class Signer(object):
def __init__(self, app=None, config_provider=None):
self.app = app
if app is not None:
self.state = self.init_app(app, config_provider)
else:
self.state = None
def init_app(self, app, config_provider):
preference = app.config.get('SIGNING_ENGINE', None)
if preference is None:
return None
if not features.ACI_CONVERSION:
return None
try:
return SIGNING_ENGINES[preference](app.config, config_provider)
except:
logger.exception('Could not initialize signing engine')
def __getattr__(self, name):
return getattr(self.state, name, None)
SIGNING_ENGINES = {
'gpg2': GPG2Signer
}

12
util/security/ssh.py Normal file
View file

@ -0,0 +1,12 @@
from __future__ import absolute_import
from Crypto.PublicKey import RSA
def generate_ssh_keypair():
"""
Generates a new 2048 bit RSA public key in OpenSSH format and private key in PEM format.
"""
key = RSA.generate(2048)
public_key = key.publickey().exportKey('OpenSSH')
private_key = key.exportKey('PEM')
return (public_key, private_key)

81
util/security/ssl.py Normal file
View file

@ -0,0 +1,81 @@
from fnmatch import fnmatch
import OpenSSL
class CertInvalidException(Exception):
""" Exception raised when a certificate could not be parsed/loaded. """
pass
class KeyInvalidException(Exception):
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
pass
def load_certificate(cert_contents):
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
on failure.
"""
try:
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
return SSLCertificate(cert)
except OpenSSL.crypto.Error as ex:
raise CertInvalidException(ex.args[0][0][2])
_SUBJECT_ALT_NAME = 'subjectAltName'
class SSLCertificate(object):
""" Helper class for easier working with SSL certificates. """
def __init__(self, openssl_cert):
self.openssl_cert = openssl_cert
def validate_private_key(self, private_key_path):
""" Validates that the private key found at the given file path applies to this certificate.
Raises a KeyInvalidException on failure.
"""
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
context.use_certificate(self.openssl_cert)
try:
context.use_privatekey_file(private_key_path)
context.check_privatekey()
except OpenSSL.SSL.Error as ex:
raise KeyInvalidException(ex.args[0][0][2])
def matches_name(self, check_name):
""" Returns true if this SSL certificate matches the given DNS hostname. """
for dns_name in self.names:
if fnmatch(check_name, dns_name):
return True
return False
@property
def expired(self):
""" Returns whether the SSL certificate has expired. """
return self.openssl_cert.has_expired()
@property
def common_name(self):
""" Returns the defined common name for the certificate, if any. """
return self.openssl_cert.get_subject().commonName
@property
def names(self):
""" Returns all the DNS named to which the certificate applies. May be empty. """
dns_names = set()
common_name = self.common_name
if common_name is not None:
dns_names.add(common_name)
# Find the DNS extension, if any.
for i in range(0, self.openssl_cert.get_extension_count()):
ext = self.openssl_cert.get_extension(i)
if ext.get_short_name() == _SUBJECT_ALT_NAME:
value = str(ext)
for san_name in value.split(','):
san_name_trimmed = san_name.strip()
if san_name_trimmed.startswith('DNS:'):
dns_names.add(san_name_trimmed[4:])
return dns_names

View file

View file

@ -0,0 +1,119 @@
import time
import pytest
import jwt
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from util.security.jwtutil import (decode, exp_max_s_option, jwk_dict_to_public_key,
InvalidTokenError, InvalidAlgorithmError)
@pytest.fixture(scope="session")
def private_key():
return RSA.generate(2048)
@pytest.fixture(scope="session")
def private_key_pem(private_key):
return private_key.exportKey('PEM')
@pytest.fixture(scope="session")
def public_key(private_key):
return private_key.publickey().exportKey('PEM')
def _token_data(audience, subject, iss, iat=None, exp=None, nbf=None):
return {
'iss': iss,
'aud': audience,
'nbf': nbf() if nbf is not None else int(time.time()),
'iat': iat() if iat is not None else int(time.time()),
'exp': exp() if exp is not None else int(time.time() + 3600),
'sub': subject,
}
@pytest.mark.parametrize('aud, iss, nbf, iat, exp, expected_exception', [
pytest.param('invalidaudience', 'someissuer', None, None, None, 'Invalid audience',
id='invalid audience'),
pytest.param('someaudience', 'invalidissuer', None, None, None, 'Invalid issuer',
id='invalid issuer'),
pytest.param('someaudience', 'someissuer', lambda: time.time() + 120, None, None,
'The token is not yet valid',
id='invalid not before'),
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 120, None,
'Issued At claim',
id='issued at in future'),
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 100,
'Signature has expired',
id='already expired'),
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() + 10000,
'Token was signed for more than',
id='expiration too far in future'),
pytest.param('someaudience', 'someissuer', lambda: time.time() + 10, None, None,
None,
id='not before in future by within leeway'),
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 10, None,
None,
id='issued at in future but within leeway'),
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 10,
None,
id='expiration in past but within leeway'),
])
def test_decode_jwt_validation(aud, iss, nbf, iat, exp, expected_exception, private_key_pem,
public_key):
token = jwt.encode(_token_data(aud, 'subject', iss, iat, exp, nbf), private_key_pem, 'RS256')
if expected_exception is not None:
with pytest.raises(InvalidTokenError) as ite:
max_exp = exp_max_s_option(3600)
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
issuer='someissuer', options=max_exp, leeway=60)
assert ite.match(expected_exception)
else:
max_exp = exp_max_s_option(3600)
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
issuer='someissuer', options=max_exp, leeway=60)
def test_decode_jwt_invalid_key(private_key_pem):
# Encode with the test private key.
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
# Try to decode with a different public key.
another_public_key = RSA.generate(2048).publickey().exportKey('PEM')
with pytest.raises(InvalidTokenError) as ite:
max_exp = exp_max_s_option(3600)
decode(token, another_public_key, algorithms=['RS256'], audience='aud',
issuer='someissuer', options=max_exp, leeway=60)
assert ite.match('Signature verification failed')
def test_decode_jwt_invalid_algorithm(private_key_pem, public_key):
# Encode with the test private key.
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
# Attempt to decode but only with a different algorithm than that used.
with pytest.raises(InvalidAlgorithmError) as ite:
max_exp = exp_max_s_option(3600)
decode(token, public_key, algorithms=['ES256'], audience='aud',
issuer='someissuer', options=max_exp, leeway=60)
assert ite.match('are not whitelisted')
def test_jwk_dict_to_public_key(private_key, private_key_pem):
public_key = private_key.publickey()
jwk = RSAKey(key=private_key.publickey()).serialize()
converted = jwk_dict_to_public_key(jwk)
# Encode with the test private key.
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
# Decode with the converted key.
max_exp = exp_max_s_option(3600)
decode(token, converted, algorithms=['RS256'], audience='aud',
issuer='someissuer', options=max_exp, leeway=60)

View file

@ -0,0 +1,116 @@
from tempfile import NamedTemporaryFile
import pytest
from OpenSSL import crypto
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000):
""" Generates a test SSL certificate and returns the certificate data and private key data. """
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
# Create a key pair.
k = crypto.PKey()
k.generate_key(crypto.TYPE_RSA, 1024)
# Create a self-signed cert.
cert = crypto.X509()
cert.get_subject().CN = hostname
# Add the subjectAltNames (if necessary).
if san_list is not None:
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
cert.gmtime_adj_notAfter(expires)
cert.set_issuer(cert.get_subject())
cert.set_pubkey(k)
cert.sign(k, 'sha1')
# Dump the certificate and private key in PEM format.
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
return (cert_data, key_data)
def test_load_certificate():
# Try loading an invalid certificate.
with pytest.raises(CertInvalidException):
load_certificate('someinvalidcontents')
# Load a valid certificate.
(public_key_data, _) = generate_test_cert()
cert = load_certificate(public_key_data)
assert not cert.expired
assert cert.names == set(['somehostname'])
assert cert.matches_name('somehostname')
def test_expired_certificate():
(public_key_data, _) = generate_test_cert(expires=-100)
cert = load_certificate(public_key_data)
assert cert.expired
def test_hostnames():
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
cert = load_certificate(public_key_data)
assert cert.names == set(['foo', 'bar', 'baz'])
for name in cert.names:
assert cert.matches_name(name)
def test_wildcard_hostnames():
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar'])
cert = load_certificate(public_key_data)
assert cert.names == set(['foo', '*.bar'])
for name in cert.names:
assert cert.matches_name(name)
assert cert.matches_name('something.bar')
assert cert.matches_name('somethingelse.bar')
assert cert.matches_name('cool.bar')
assert not cert.matches_name('*')
def test_nondns_hostnames():
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg'])
cert = load_certificate(public_key_data)
assert cert.names == set(['foo'])
def test_validate_private_key():
(public_key_data, private_key_data) = generate_test_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
cert.validate_private_key(private_key.name)
def test_invalid_private_key():
(public_key_data, _) = generate_test_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write('somerandomdata')
private_key.seek(0)
cert = load_certificate(public_key_data)
with pytest.raises(KeyInvalidException):
cert.validate_private_key(private_key.name)
def test_mismatch_private_key():
(public_key_data, _) = generate_test_cert()
(_, private_key_data) = generate_test_cert()
private_key = NamedTemporaryFile(delete=True)
private_key.write(private_key_data)
private_key.seek(0)
cert = load_certificate(public_key_data)
with pytest.raises(KeyInvalidException):
cert.validate_private_key(private_key.name)

32
util/security/token.py Normal file
View file

@ -0,0 +1,32 @@
from collections import namedtuple
import base64
DELIMITER = ':'
DecodedToken = namedtuple('DecodedToken', ['public_code', 'private_token'])
def encode_public_private_token(public_code, private_token, allow_public_only=False):
# NOTE: This is for legacy tokens where the private token part is None. We should remove this
# once older installations have been fully converted over (if at all possible).
if private_token is None:
assert allow_public_only
return public_code
assert isinstance(private_token, basestring)
return base64.b64encode('%s%s%s' % (public_code, DELIMITER, private_token))
def decode_public_private_token(encoded, allow_public_only=False):
try:
decoded = base64.b64decode(encoded)
except (ValueError, TypeError):
if not allow_public_only:
return None
return DecodedToken(encoded, None)
parts = decoded.split(DELIMITER, 2)
if len(parts) != 2:
return None
return DecodedToken(*parts)