initial import for Open Source 🎉
This commit is contained in:
parent
1898c361f3
commit
9c0dd3b722
2048 changed files with 218743 additions and 0 deletions
0
util/security/__init__.py
Normal file
0
util/security/__init__.py
Normal file
34
util/security/aes.py
Normal file
34
util/security/aes.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from Crypto import Random
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
class AESCipher(object):
|
||||
""" Helper class for encrypting and decrypting data via AES.
|
||||
|
||||
Copied From: http://stackoverflow.com/a/21928790
|
||||
"""
|
||||
def __init__(self, key):
|
||||
self.bs = 32
|
||||
self.key = key
|
||||
|
||||
def encrypt(self, raw):
|
||||
raw = self._pad(raw)
|
||||
iv = Random.new().read(AES.block_size)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return base64.b64encode(iv + cipher.encrypt(raw))
|
||||
|
||||
def decrypt(self, enc):
|
||||
enc = base64.b64decode(enc)
|
||||
iv = enc[:AES.block_size]
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
return self._unpad(cipher.decrypt(enc[AES.block_size:])).decode('utf-8')
|
||||
|
||||
def _pad(self, s):
|
||||
return s + (self.bs - len(s) % self.bs) * chr(self.bs - len(s) % self.bs)
|
||||
|
||||
@staticmethod
|
||||
def _unpad(s):
|
||||
return s[:-ord(s[len(s)-1:])]
|
18
util/security/crypto.py
Normal file
18
util/security/crypto.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
import base64
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
def encrypt_string(string, key):
|
||||
""" Encrypts a string with the specified key. The key must be 32 raw bytes. """
|
||||
f = Fernet(key)
|
||||
return f.encrypt(string)
|
||||
|
||||
def decrypt_string(string, key, ttl=None):
|
||||
""" Decrypts an encrypted string with the specified key. The key must be 32 raw bytes. """
|
||||
f = Fernet(key)
|
||||
try:
|
||||
return f.decrypt(str(string), ttl=ttl)
|
||||
except InvalidToken:
|
||||
return None
|
||||
except TypeError:
|
||||
return None
|
16
util/security/fingerprint.py
Normal file
16
util/security/fingerprint.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import json
|
||||
|
||||
from hashlib import sha256
|
||||
from util.canonicaljson import canonicalize
|
||||
|
||||
def canonical_kid(jwk):
|
||||
"""This function returns the SHA256 hash of a canonical JWK.
|
||||
|
||||
Args:
|
||||
jwk (object): the JWK for which a kid will be generated.
|
||||
|
||||
Returns:
|
||||
string: the unique kid for the given JWK.
|
||||
|
||||
"""
|
||||
return sha256(json.dumps(canonicalize(jwk), separators=(',', ':'))).hexdigest()
|
79
util/security/instancekeys.py
Normal file
79
util/security/instancekeys.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
from cachetools.func import lru_cache
|
||||
from data import model
|
||||
from util.expiresdict import ExpiresDict, ExpiresEntry
|
||||
from util.security import jwtutil
|
||||
|
||||
|
||||
class CachingKey(object):
|
||||
def __init__(self, service_key):
|
||||
self._service_key = service_key
|
||||
self._cached_public_key = None
|
||||
|
||||
@property
|
||||
def public_key(self):
|
||||
cached_key = self._cached_public_key
|
||||
if cached_key is not None:
|
||||
return cached_key
|
||||
|
||||
# Convert the JWK into a public key and cache it (since the conversion can take > 200ms).
|
||||
public_key = jwtutil.jwk_dict_to_public_key(self._service_key.jwk)
|
||||
self._cached_public_key = public_key
|
||||
return public_key
|
||||
|
||||
|
||||
class InstanceKeys(object):
|
||||
""" InstanceKeys defines a helper class for interacting with the Quay instance service keys
|
||||
used for JWT signing of registry tokens as well as requests from Quay to other services
|
||||
such as Clair. Each container will have a single registered instance key.
|
||||
"""
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.instance_keys = ExpiresDict(self._load_instance_keys)
|
||||
|
||||
def clear_cache(self):
|
||||
""" Clears the cache of instance keys. """
|
||||
self.instance_keys = ExpiresDict(self._load_instance_keys)
|
||||
|
||||
def _load_instance_keys(self):
|
||||
# Load all the instance keys.
|
||||
keys = {}
|
||||
for key in model.service_keys.list_service_keys(self.service_name):
|
||||
keys[key.kid] = ExpiresEntry(CachingKey(key), key.expiration_date)
|
||||
|
||||
return keys
|
||||
|
||||
@property
|
||||
def service_name(self):
|
||||
""" Returns the name of the instance key's service (i.e. 'quay'). """
|
||||
return self.app.config['INSTANCE_SERVICE_KEY_SERVICE']
|
||||
|
||||
@property
|
||||
def service_key_expiration(self):
|
||||
""" Returns the defined expiration for instance service keys, in minutes. """
|
||||
return self.app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def local_key_id(self):
|
||||
""" Returns the ID of the local instance service key. """
|
||||
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'])
|
||||
|
||||
@property
|
||||
@lru_cache(maxsize=1)
|
||||
def local_private_key(self):
|
||||
""" Returns the private key of the local instance service key. """
|
||||
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_LOCATION'])
|
||||
|
||||
def get_service_key_public_key(self, kid):
|
||||
""" Returns the public key associated with the given instance service key or None if none. """
|
||||
caching_key = self.instance_keys.get(kid)
|
||||
if caching_key is None:
|
||||
return None
|
||||
|
||||
return caching_key.public_key
|
||||
|
||||
|
||||
def _load_file_contents(path):
|
||||
""" Returns the contents of the specified file path. """
|
||||
with open(path) as f:
|
||||
return f.read()
|
116
util/security/jwtutil.py
Normal file
116
util/security/jwtutil.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
import re
|
||||
|
||||
from calendar import timegm
|
||||
from datetime import datetime, timedelta
|
||||
from jwt import PyJWT
|
||||
from jwt.exceptions import (
|
||||
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
||||
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
|
||||
InvalidAlgorithmError
|
||||
)
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
|
||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
||||
from jwkest.jwk import keyrep, RSAKey, ECKey
|
||||
|
||||
|
||||
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
|
||||
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
|
||||
|
||||
# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD
|
||||
# `none` here!
|
||||
ALGORITHM_WHITELIST = [
|
||||
'rs256'
|
||||
]
|
||||
|
||||
class _StrictJWT(PyJWT):
|
||||
""" _StrictJWT defines a JWT decoder with extra checks. """
|
||||
|
||||
@staticmethod
|
||||
def _get_default_options():
|
||||
# Weird syntax to call super on a staticmethod
|
||||
defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
|
||||
defaults.update({
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
'require_nbf': True,
|
||||
'exp_max_s': None,
|
||||
})
|
||||
return defaults
|
||||
|
||||
def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
|
||||
if options.get('exp_max_s') is not None:
|
||||
if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
|
||||
raise ValueError('exp_max_s option implies verify_expiration')
|
||||
|
||||
options['verify_exp'] = True
|
||||
|
||||
# Do all of the other checks
|
||||
super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
||||
|
||||
now = timegm(datetime.utcnow().utctimetuple())
|
||||
self._reject_future_iat(payload, now, leeway)
|
||||
|
||||
if 'exp' in payload and options.get('exp_max_s') is not None:
|
||||
# Validate that the expiration was not more than exp_max_s seconds after the issue time
|
||||
# or in the absence of an issue time, more than exp_max_s in the future from now
|
||||
|
||||
# This will work because the parent method already checked the type of exp
|
||||
expiration = datetime.utcfromtimestamp(int(payload['exp']))
|
||||
max_signed_s = options.get('exp_max_s')
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
if 'iat' in payload:
|
||||
start_time = datetime.utcfromtimestamp(int(payload['iat']))
|
||||
|
||||
if expiration > start_time + timedelta(seconds=max_signed_s):
|
||||
raise InvalidTokenError('Token was signed for more than %s seconds from %s', max_signed_s,
|
||||
start_time)
|
||||
|
||||
def _reject_future_iat(self, payload, now, leeway):
|
||||
try:
|
||||
iat = int(payload['iat'])
|
||||
except ValueError:
|
||||
raise DecodeError('Issued At claim (iat) must be an integer.')
|
||||
|
||||
if iat > (now + leeway):
|
||||
raise InvalidIssuedAtError('Issued At claim (iat) cannot be in'
|
||||
' the future.')
|
||||
|
||||
|
||||
def decode(jwt, key='', verify=True, algorithms=None, options=None,
|
||||
**kwargs):
|
||||
""" Decodes a JWT. """
|
||||
if not algorithms:
|
||||
raise InvalidAlgorithmError('algorithms must be specified')
|
||||
|
||||
normalized = set([a.lower() for a in algorithms])
|
||||
if 'none' in normalized:
|
||||
raise InvalidAlgorithmError('`none` algorithm is not allowed')
|
||||
|
||||
if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized):
|
||||
raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
|
||||
(algorithms, ALGORITHM_WHITELIST))
|
||||
|
||||
return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs)
|
||||
|
||||
|
||||
def exp_max_s_option(max_exp_s):
|
||||
""" Returns an options dictionary that sets the maximum expiration seconds for a JWT. """
|
||||
return {
|
||||
'exp_max_s': max_exp_s,
|
||||
}
|
||||
|
||||
|
||||
def jwk_dict_to_public_key(jwk):
|
||||
""" Converts the specified JWK into a public key. """
|
||||
jwkest_key = keyrep(jwk)
|
||||
if isinstance(jwkest_key, RSAKey):
|
||||
pycrypto_key = jwkest_key.key
|
||||
return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
|
||||
elif isinstance(jwkest_key, ECKey):
|
||||
x, y = jwkest_key.get_key()
|
||||
return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())
|
||||
|
||||
raise Exception('Unsupported kind of JWK: %s', str(type(jwkest_key)))
|
134
util/security/registry_jwt.py
Normal file
134
util/security/registry_jwt.py
Normal file
|
@ -0,0 +1,134 @@
|
|||
import time
|
||||
import jwt
|
||||
import logging
|
||||
|
||||
from util.security import jwtutil
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ANONYMOUS_SUB = '(anonymous)'
|
||||
ALGORITHM = 'RS256'
|
||||
CLAIM_TUF_ROOTS = 'com.apostille.roots'
|
||||
CLAIM_TUF_ROOT = 'com.apostille.root'
|
||||
QUAY_TUF_ROOT = 'quay'
|
||||
SIGNER_TUF_ROOT = 'signer'
|
||||
DISABLED_TUF_ROOT = '$disabled'
|
||||
|
||||
# The number of allowed seconds of clock skew for a JWT. The iat, nbf and exp are adjusted with this
|
||||
# count.
|
||||
JWT_CLOCK_SKEW_SECONDS = 30
|
||||
|
||||
|
||||
class InvalidBearerTokenException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def decode_bearer_header(bearer_header, instance_keys, config, metric_queue=None):
|
||||
""" decode_bearer_header decodes the given bearer header that contains an encoded JWT with both
|
||||
a Key ID as well as the signed JWT and returns the decoded and validated JWT. On any error,
|
||||
raises an InvalidBearerTokenException with the reason for failure.
|
||||
"""
|
||||
# Extract the jwt token from the header
|
||||
match = jwtutil.TOKEN_REGEX.match(bearer_header)
|
||||
if match is None:
|
||||
raise InvalidBearerTokenException('Invalid bearer token format')
|
||||
|
||||
encoded_jwt = match.group(1)
|
||||
logger.debug('encoded JWT: %s', encoded_jwt)
|
||||
return decode_bearer_token(encoded_jwt, instance_keys, config, metric_queue=metric_queue)
|
||||
|
||||
|
||||
def decode_bearer_token(bearer_token, instance_keys, config, metric_queue=None):
|
||||
""" decode_bearer_token decodes the given bearer token that contains both a Key ID as well as the
|
||||
encoded JWT and returns the decoded and validated JWT. On any error, raises an
|
||||
InvalidBearerTokenException with the reason for failure.
|
||||
"""
|
||||
# Decode the key ID.
|
||||
try:
|
||||
headers = jwt.get_unverified_header(bearer_token)
|
||||
except jwtutil.InvalidTokenError as ite:
|
||||
logger.exception('Invalid token reason: %s', ite)
|
||||
raise InvalidBearerTokenException(ite)
|
||||
|
||||
kid = headers.get('kid', None)
|
||||
if kid is None:
|
||||
logger.error('Missing kid header on encoded JWT: %s', bearer_token)
|
||||
raise InvalidBearerTokenException('Missing kid header')
|
||||
|
||||
# Find the matching public key.
|
||||
public_key = instance_keys.get_service_key_public_key(kid)
|
||||
if public_key is None:
|
||||
if metric_queue is not None:
|
||||
metric_queue.invalid_instance_key_count.Inc(labelvalues=[kid])
|
||||
|
||||
logger.error('Could not find requested service key %s with encoded JWT: %s', kid, bearer_token)
|
||||
raise InvalidBearerTokenException('Unknown service key')
|
||||
|
||||
# Load the JWT returned.
|
||||
try:
|
||||
expected_issuer = instance_keys.service_name
|
||||
audience = config['SERVER_HOSTNAME']
|
||||
max_signed_s = config.get('REGISTRY_JWT_AUTH_MAX_FRESH_S', 3660)
|
||||
max_exp = jwtutil.exp_max_s_option(max_signed_s)
|
||||
payload = jwtutil.decode(bearer_token, public_key, algorithms=[ALGORITHM], audience=audience,
|
||||
issuer=expected_issuer, options=max_exp, leeway=JWT_CLOCK_SKEW_SECONDS)
|
||||
except jwtutil.InvalidTokenError as ite:
|
||||
logger.exception('Invalid token reason: %s', ite)
|
||||
raise InvalidBearerTokenException(ite)
|
||||
|
||||
if not 'sub' in payload:
|
||||
raise InvalidBearerTokenException('Missing sub field in JWT')
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
def generate_bearer_token(audience, subject, context, access, lifetime_s, instance_keys):
|
||||
""" Generates a registry bearer token (without the 'Bearer ' portion) based on the given
|
||||
information.
|
||||
"""
|
||||
return _generate_jwt_object(audience, subject, context, access, lifetime_s,
|
||||
instance_keys.service_name, instance_keys.local_key_id,
|
||||
instance_keys.local_private_key)
|
||||
|
||||
|
||||
def _generate_jwt_object(audience, subject, context, access, lifetime_s, issuer, key_id,
|
||||
private_key):
|
||||
""" Generates a compact encoded JWT with the values specified. """
|
||||
token_data = {
|
||||
'iss': issuer,
|
||||
'aud': audience,
|
||||
'nbf': int(time.time()),
|
||||
'iat': int(time.time()),
|
||||
'exp': int(time.time() + lifetime_s),
|
||||
'sub': subject,
|
||||
'access': access,
|
||||
'context': context,
|
||||
}
|
||||
|
||||
token_headers = {
|
||||
'kid': key_id,
|
||||
}
|
||||
|
||||
return jwt.encode(token_data, private_key, ALGORITHM, headers=token_headers)
|
||||
|
||||
|
||||
def build_context_and_subject(auth_context=None, tuf_roots=None):
|
||||
""" Builds the custom context field for the JWT signed token and returns it,
|
||||
along with the subject for the JWT signed token. """
|
||||
# Serialize to a dictionary.
|
||||
context = auth_context.to_signed_dict() if auth_context else {}
|
||||
|
||||
# TODO: remove once Apostille has been upgraded to not use the single root.
|
||||
single_root = (tuf_roots.values()[0]
|
||||
if tuf_roots is not None and len(tuf_roots) == 1
|
||||
else DISABLED_TUF_ROOT)
|
||||
|
||||
context.update({
|
||||
CLAIM_TUF_ROOTS: tuf_roots,
|
||||
CLAIM_TUF_ROOT: single_root,
|
||||
})
|
||||
|
||||
if not auth_context or auth_context.is_anonymous:
|
||||
return (context, ANONYMOUS_SUB)
|
||||
|
||||
return (context, auth_context.authed_user.username if auth_context.authed_user else None)
|
28
util/security/secret.py
Normal file
28
util/security/secret.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
import itertools
|
||||
import uuid
|
||||
|
||||
def convert_secret_key(config_secret_key):
|
||||
""" Converts the secret key from the app config into a secret key that is usable by AES
|
||||
Cipher. """
|
||||
secret_key = None
|
||||
|
||||
# First try parsing the key as an int.
|
||||
try:
|
||||
big_int = int(config_secret_key)
|
||||
secret_key = str(bytearray.fromhex('{:02x}'.format(big_int)))
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Next try parsing it as an UUID.
|
||||
if secret_key is None:
|
||||
try:
|
||||
secret_key = uuid.UUID(config_secret_key).bytes
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if secret_key is None:
|
||||
secret_key = str(bytearray(map(ord, config_secret_key)))
|
||||
|
||||
# Otherwise, use the bytes directly.
|
||||
assert len(secret_key)
|
||||
return ''.join(itertools.islice(itertools.cycle(secret_key), 32))
|
82
util/security/signing.py
Normal file
82
util/security/signing.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
import gpgme
|
||||
import os
|
||||
import features
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from StringIO import StringIO
|
||||
|
||||
class GPG2Signer(object):
|
||||
""" Helper class for signing data using GPG2. """
|
||||
def __init__(self, config, config_provider):
|
||||
if not config.get('GPG2_PRIVATE_KEY_NAME'):
|
||||
raise Exception('Missing configuration key GPG2_PRIVATE_KEY_NAME')
|
||||
|
||||
if not config.get('GPG2_PRIVATE_KEY_FILENAME'):
|
||||
raise Exception('Missing configuration key GPG2_PRIVATE_KEY_FILENAME')
|
||||
|
||||
if not config.get('GPG2_PUBLIC_KEY_FILENAME'):
|
||||
raise Exception('Missing configuration key GPG2_PUBLIC_KEY_FILENAME')
|
||||
|
||||
self._ctx = gpgme.Context()
|
||||
self._ctx.armor = True
|
||||
self._private_key_name = config['GPG2_PRIVATE_KEY_NAME']
|
||||
self._public_key_filename = config['GPG2_PUBLIC_KEY_FILENAME']
|
||||
self._config_provider = config_provider
|
||||
|
||||
if not config_provider.volume_file_exists(config['GPG2_PRIVATE_KEY_FILENAME']):
|
||||
raise Exception('Missing key file %s' % config['GPG2_PRIVATE_KEY_FILENAME'])
|
||||
|
||||
with config_provider.get_volume_file(config['GPG2_PRIVATE_KEY_FILENAME'], mode='rb') as fp:
|
||||
self._ctx.import_(fp)
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return 'gpg2'
|
||||
|
||||
def open_public_key_file(self):
|
||||
return self._config_provider.get_volume_file(self._public_key_filename, mode='rb')
|
||||
|
||||
def detached_sign(self, stream):
|
||||
""" Signs the given stream, returning the signature. """
|
||||
ctx = self._ctx
|
||||
try:
|
||||
ctx.signers = [ctx.get_key(self._private_key_name)]
|
||||
except:
|
||||
raise Exception('Invalid private key name')
|
||||
|
||||
signature = StringIO()
|
||||
new_sigs = ctx.sign(stream, signature, gpgme.SIG_MODE_DETACH)
|
||||
signature.seek(0)
|
||||
return signature.getvalue()
|
||||
|
||||
|
||||
class Signer(object):
|
||||
def __init__(self, app=None, config_provider=None):
|
||||
self.app = app
|
||||
if app is not None:
|
||||
self.state = self.init_app(app, config_provider)
|
||||
else:
|
||||
self.state = None
|
||||
|
||||
def init_app(self, app, config_provider):
|
||||
preference = app.config.get('SIGNING_ENGINE', None)
|
||||
if preference is None:
|
||||
return None
|
||||
|
||||
if not features.ACI_CONVERSION:
|
||||
return None
|
||||
|
||||
try:
|
||||
return SIGNING_ENGINES[preference](app.config, config_provider)
|
||||
except:
|
||||
logger.exception('Could not initialize signing engine')
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.state, name, None)
|
||||
|
||||
|
||||
SIGNING_ENGINES = {
|
||||
'gpg2': GPG2Signer
|
||||
}
|
12
util/security/ssh.py
Normal file
12
util/security/ssh.py
Normal file
|
@ -0,0 +1,12 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
|
||||
def generate_ssh_keypair():
|
||||
"""
|
||||
Generates a new 2048 bit RSA public key in OpenSSH format and private key in PEM format.
|
||||
"""
|
||||
key = RSA.generate(2048)
|
||||
public_key = key.publickey().exportKey('OpenSSH')
|
||||
private_key = key.exportKey('PEM')
|
||||
return (public_key, private_key)
|
81
util/security/ssl.py
Normal file
81
util/security/ssl.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
from fnmatch import fnmatch
|
||||
|
||||
import OpenSSL
|
||||
|
||||
class CertInvalidException(Exception):
|
||||
""" Exception raised when a certificate could not be parsed/loaded. """
|
||||
pass
|
||||
|
||||
class KeyInvalidException(Exception):
|
||||
""" Exception raised when a key could not be parsed/loaded or successfully applied to a cert. """
|
||||
pass
|
||||
|
||||
|
||||
def load_certificate(cert_contents):
|
||||
""" Loads the certificate from the given contents and returns it or raises a CertInvalidException
|
||||
on failure.
|
||||
"""
|
||||
try:
|
||||
cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, cert_contents)
|
||||
return SSLCertificate(cert)
|
||||
except OpenSSL.crypto.Error as ex:
|
||||
raise CertInvalidException(ex.args[0][0][2])
|
||||
|
||||
|
||||
_SUBJECT_ALT_NAME = 'subjectAltName'
|
||||
|
||||
class SSLCertificate(object):
|
||||
""" Helper class for easier working with SSL certificates. """
|
||||
def __init__(self, openssl_cert):
|
||||
self.openssl_cert = openssl_cert
|
||||
|
||||
def validate_private_key(self, private_key_path):
|
||||
""" Validates that the private key found at the given file path applies to this certificate.
|
||||
Raises a KeyInvalidException on failure.
|
||||
"""
|
||||
context = OpenSSL.SSL.Context(OpenSSL.SSL.TLSv1_METHOD)
|
||||
context.use_certificate(self.openssl_cert)
|
||||
|
||||
try:
|
||||
context.use_privatekey_file(private_key_path)
|
||||
context.check_privatekey()
|
||||
except OpenSSL.SSL.Error as ex:
|
||||
raise KeyInvalidException(ex.args[0][0][2])
|
||||
|
||||
def matches_name(self, check_name):
|
||||
""" Returns true if this SSL certificate matches the given DNS hostname. """
|
||||
for dns_name in self.names:
|
||||
if fnmatch(check_name, dns_name):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@property
|
||||
def expired(self):
|
||||
""" Returns whether the SSL certificate has expired. """
|
||||
return self.openssl_cert.has_expired()
|
||||
|
||||
@property
|
||||
def common_name(self):
|
||||
""" Returns the defined common name for the certificate, if any. """
|
||||
return self.openssl_cert.get_subject().commonName
|
||||
|
||||
@property
|
||||
def names(self):
|
||||
""" Returns all the DNS named to which the certificate applies. May be empty. """
|
||||
dns_names = set()
|
||||
common_name = self.common_name
|
||||
if common_name is not None:
|
||||
dns_names.add(common_name)
|
||||
|
||||
# Find the DNS extension, if any.
|
||||
for i in range(0, self.openssl_cert.get_extension_count()):
|
||||
ext = self.openssl_cert.get_extension(i)
|
||||
if ext.get_short_name() == _SUBJECT_ALT_NAME:
|
||||
value = str(ext)
|
||||
for san_name in value.split(','):
|
||||
san_name_trimmed = san_name.strip()
|
||||
if san_name_trimmed.startswith('DNS:'):
|
||||
dns_names.add(san_name_trimmed[4:])
|
||||
|
||||
return dns_names
|
0
util/security/test/__init__.py
Normal file
0
util/security/test/__init__.py
Normal file
119
util/security/test/test_jwtutil.py
Normal file
119
util/security/test/test_jwtutil.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import time
|
||||
|
||||
import pytest
|
||||
import jwt
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
from jwkest.jwk import RSAKey
|
||||
|
||||
from util.security.jwtutil import (decode, exp_max_s_option, jwk_dict_to_public_key,
|
||||
InvalidTokenError, InvalidAlgorithmError)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def private_key():
|
||||
return RSA.generate(2048)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def private_key_pem(private_key):
|
||||
return private_key.exportKey('PEM')
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def public_key(private_key):
|
||||
return private_key.publickey().exportKey('PEM')
|
||||
|
||||
|
||||
def _token_data(audience, subject, iss, iat=None, exp=None, nbf=None):
|
||||
return {
|
||||
'iss': iss,
|
||||
'aud': audience,
|
||||
'nbf': nbf() if nbf is not None else int(time.time()),
|
||||
'iat': iat() if iat is not None else int(time.time()),
|
||||
'exp': exp() if exp is not None else int(time.time() + 3600),
|
||||
'sub': subject,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize('aud, iss, nbf, iat, exp, expected_exception', [
|
||||
pytest.param('invalidaudience', 'someissuer', None, None, None, 'Invalid audience',
|
||||
id='invalid audience'),
|
||||
pytest.param('someaudience', 'invalidissuer', None, None, None, 'Invalid issuer',
|
||||
id='invalid issuer'),
|
||||
pytest.param('someaudience', 'someissuer', lambda: time.time() + 120, None, None,
|
||||
'The token is not yet valid',
|
||||
id='invalid not before'),
|
||||
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 120, None,
|
||||
'Issued At claim',
|
||||
id='issued at in future'),
|
||||
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 100,
|
||||
'Signature has expired',
|
||||
id='already expired'),
|
||||
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() + 10000,
|
||||
'Token was signed for more than',
|
||||
id='expiration too far in future'),
|
||||
|
||||
pytest.param('someaudience', 'someissuer', lambda: time.time() + 10, None, None,
|
||||
None,
|
||||
id='not before in future by within leeway'),
|
||||
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 10, None,
|
||||
None,
|
||||
id='issued at in future but within leeway'),
|
||||
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 10,
|
||||
None,
|
||||
id='expiration in past but within leeway'),
|
||||
])
|
||||
def test_decode_jwt_validation(aud, iss, nbf, iat, exp, expected_exception, private_key_pem,
|
||||
public_key):
|
||||
token = jwt.encode(_token_data(aud, 'subject', iss, iat, exp, nbf), private_key_pem, 'RS256')
|
||||
|
||||
if expected_exception is not None:
|
||||
with pytest.raises(InvalidTokenError) as ite:
|
||||
max_exp = exp_max_s_option(3600)
|
||||
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
|
||||
issuer='someissuer', options=max_exp, leeway=60)
|
||||
assert ite.match(expected_exception)
|
||||
else:
|
||||
max_exp = exp_max_s_option(3600)
|
||||
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
|
||||
issuer='someissuer', options=max_exp, leeway=60)
|
||||
|
||||
|
||||
def test_decode_jwt_invalid_key(private_key_pem):
|
||||
# Encode with the test private key.
|
||||
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||
|
||||
# Try to decode with a different public key.
|
||||
another_public_key = RSA.generate(2048).publickey().exportKey('PEM')
|
||||
with pytest.raises(InvalidTokenError) as ite:
|
||||
max_exp = exp_max_s_option(3600)
|
||||
decode(token, another_public_key, algorithms=['RS256'], audience='aud',
|
||||
issuer='someissuer', options=max_exp, leeway=60)
|
||||
assert ite.match('Signature verification failed')
|
||||
|
||||
|
||||
def test_decode_jwt_invalid_algorithm(private_key_pem, public_key):
|
||||
# Encode with the test private key.
|
||||
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||
|
||||
# Attempt to decode but only with a different algorithm than that used.
|
||||
with pytest.raises(InvalidAlgorithmError) as ite:
|
||||
max_exp = exp_max_s_option(3600)
|
||||
decode(token, public_key, algorithms=['ES256'], audience='aud',
|
||||
issuer='someissuer', options=max_exp, leeway=60)
|
||||
assert ite.match('are not whitelisted')
|
||||
|
||||
|
||||
def test_jwk_dict_to_public_key(private_key, private_key_pem):
|
||||
public_key = private_key.publickey()
|
||||
jwk = RSAKey(key=private_key.publickey()).serialize()
|
||||
converted = jwk_dict_to_public_key(jwk)
|
||||
|
||||
# Encode with the test private key.
|
||||
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||
|
||||
# Decode with the converted key.
|
||||
max_exp = exp_max_s_option(3600)
|
||||
decode(token, converted, algorithms=['RS256'], audience='aud',
|
||||
issuer='someissuer', options=max_exp, leeway=60)
|
116
util/security/test/test_ssl_util.py
Normal file
116
util/security/test/test_ssl_util.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import pytest
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
from util.security.ssl import load_certificate, CertInvalidException, KeyInvalidException
|
||||
|
||||
def generate_test_cert(hostname='somehostname', san_list=None, expires=1000000):
|
||||
""" Generates a test SSL certificate and returns the certificate data and private key data. """
|
||||
|
||||
# Based on: http://blog.richardknop.com/2012/08/create-a-self-signed-x509-certificate-in-python/
|
||||
# Create a key pair.
|
||||
k = crypto.PKey()
|
||||
k.generate_key(crypto.TYPE_RSA, 1024)
|
||||
|
||||
# Create a self-signed cert.
|
||||
cert = crypto.X509()
|
||||
cert.get_subject().CN = hostname
|
||||
|
||||
# Add the subjectAltNames (if necessary).
|
||||
if san_list is not None:
|
||||
cert.add_extensions([crypto.X509Extension("subjectAltName", False, ", ".join(san_list))])
|
||||
|
||||
cert.set_serial_number(1000)
|
||||
cert.gmtime_adj_notBefore(0)
|
||||
cert.gmtime_adj_notAfter(expires)
|
||||
cert.set_issuer(cert.get_subject())
|
||||
|
||||
cert.set_pubkey(k)
|
||||
cert.sign(k, 'sha1')
|
||||
|
||||
# Dump the certificate and private key in PEM format.
|
||||
cert_data = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
|
||||
key_data = crypto.dump_privatekey(crypto.FILETYPE_PEM, k)
|
||||
|
||||
return (cert_data, key_data)
|
||||
|
||||
|
||||
def test_load_certificate():
|
||||
# Try loading an invalid certificate.
|
||||
with pytest.raises(CertInvalidException):
|
||||
load_certificate('someinvalidcontents')
|
||||
|
||||
# Load a valid certificate.
|
||||
(public_key_data, _) = generate_test_cert()
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
assert not cert.expired
|
||||
assert cert.names == set(['somehostname'])
|
||||
assert cert.matches_name('somehostname')
|
||||
|
||||
def test_expired_certificate():
|
||||
(public_key_data, _) = generate_test_cert(expires=-100)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
assert cert.expired
|
||||
|
||||
def test_hostnames():
|
||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:bar', 'DNS:baz'])
|
||||
cert = load_certificate(public_key_data)
|
||||
assert cert.names == set(['foo', 'bar', 'baz'])
|
||||
|
||||
for name in cert.names:
|
||||
assert cert.matches_name(name)
|
||||
|
||||
def test_wildcard_hostnames():
|
||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['DNS:*.bar'])
|
||||
cert = load_certificate(public_key_data)
|
||||
assert cert.names == set(['foo', '*.bar'])
|
||||
|
||||
for name in cert.names:
|
||||
assert cert.matches_name(name)
|
||||
|
||||
assert cert.matches_name('something.bar')
|
||||
assert cert.matches_name('somethingelse.bar')
|
||||
assert cert.matches_name('cool.bar')
|
||||
assert not cert.matches_name('*')
|
||||
|
||||
def test_nondns_hostnames():
|
||||
(public_key_data, _) = generate_test_cert(hostname='foo', san_list=['URI:yarg'])
|
||||
cert = load_certificate(public_key_data)
|
||||
assert cert.names == set(['foo'])
|
||||
|
||||
def test_validate_private_key():
|
||||
(public_key_data, private_key_data) = generate_test_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write(private_key_data)
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
cert.validate_private_key(private_key.name)
|
||||
|
||||
def test_invalid_private_key():
|
||||
(public_key_data, _) = generate_test_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write('somerandomdata')
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
with pytest.raises(KeyInvalidException):
|
||||
cert.validate_private_key(private_key.name)
|
||||
|
||||
def test_mismatch_private_key():
|
||||
(public_key_data, _) = generate_test_cert()
|
||||
(_, private_key_data) = generate_test_cert()
|
||||
|
||||
private_key = NamedTemporaryFile(delete=True)
|
||||
private_key.write(private_key_data)
|
||||
private_key.seek(0)
|
||||
|
||||
cert = load_certificate(public_key_data)
|
||||
with pytest.raises(KeyInvalidException):
|
||||
cert.validate_private_key(private_key.name)
|
32
util/security/token.py
Normal file
32
util/security/token.py
Normal file
|
@ -0,0 +1,32 @@
|
|||
from collections import namedtuple
|
||||
|
||||
import base64
|
||||
|
||||
DELIMITER = ':'
|
||||
DecodedToken = namedtuple('DecodedToken', ['public_code', 'private_token'])
|
||||
|
||||
def encode_public_private_token(public_code, private_token, allow_public_only=False):
|
||||
# NOTE: This is for legacy tokens where the private token part is None. We should remove this
|
||||
# once older installations have been fully converted over (if at all possible).
|
||||
if private_token is None:
|
||||
assert allow_public_only
|
||||
return public_code
|
||||
|
||||
assert isinstance(private_token, basestring)
|
||||
return base64.b64encode('%s%s%s' % (public_code, DELIMITER, private_token))
|
||||
|
||||
|
||||
def decode_public_private_token(encoded, allow_public_only=False):
|
||||
try:
|
||||
decoded = base64.b64decode(encoded)
|
||||
except (ValueError, TypeError):
|
||||
if not allow_public_only:
|
||||
return None
|
||||
|
||||
return DecodedToken(encoded, None)
|
||||
|
||||
parts = decoded.split(DELIMITER, 2)
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
|
||||
return DecodedToken(*parts)
|
Reference in a new issue