Use the instance service key for registry JWT signing

This commit is contained in:
Joseph Schorr 2016-05-31 16:48:19 -04:00
parent a4aa5cc02a
commit 8887f09ba8
26 changed files with 457 additions and 278 deletions

55
util/expiresdict.py Normal file
View file

@ -0,0 +1,55 @@
from datetime import datetime
class ExpiresEntry(object):
""" A single entry under a ExpiresDict. """
def __init__(self, value, expires=None):
self.value = value
self._expiration = expires
@property
def expired(self):
if self._expiration is None:
return False
return datetime.now() >= self._expiration
class ExpiresDict(object):
""" ExpiresDict defines a dictionary-like class whose keys have expiration. The rebuilder is
a function that returns the full contents of the cached dictionary as a dict of the keys
and whose values are TTLEntry's.
"""
def __init__(self, rebuilder):
self._rebuilder = rebuilder
self._items = {}
def __getitem__(self, key):
found = self.get(key)
if found is None:
raise KeyError
return found
def get(self, key, default_value=None):
# Check the cache first. If the key is found and it has not yet expired,
# return it.
found = self._items.get(key)
if found is not None and not found.expired:
return found.value
# Otherwise the key has expired or was not found. Rebuild the cache and check it again.
self._rebuild()
found = self._items.get(key)
if found is None:
return default_value
return found.value
def __contains__(self, key):
return self.get(key) is not None
def _rebuild(self):
self._items = self._rebuilder()
def set(self, key, value, expires=None):
self._items[key] = ExpiresEntry(value, expires=expires)

View file

@ -7,15 +7,12 @@ import urllib
from cachetools import lru_cache
from app import app
from app import app, instance_keys
ANNOUNCE_URL = app.config.get('BITTORRENT_ANNOUNCE_URL')
PRIVATE_KEY_LOCATION = app.config.get('INSTANCE_SERVICE_KEY_LOCATION')
FILENAME_PEPPER = app.config.get('BITTORRENT_FILENAME_PEPPER')
REGISTRY_TITLE = app.config.get('REGISTRY_TITLE')
JWT_ISSUER = app.config.get('JWT_AUTH_TOKEN_ISSUER')
ANNOUNCE_URL = app.config['BITTORRENT_ANNOUNCE_URL']
FILENAME_PEPPER = app.config['BITTORRENT_FILENAME_PEPPER']
REGISTRY_TITLE = app.config['REGISTRY_TITLE']
@lru_cache(maxsize=1)
def _load_private_key(private_key_file_path):
@ -24,13 +21,12 @@ def _load_private_key(private_key_file_path):
def _torrent_jwt(info_dict):
token_data = {
'iss': JWT_ISSUER,
'iss': instance_keys.service_name,
'aud': ANNOUNCE_URL,
'infohash': _infohash(info_dict),
}
private_key = _load_private_key(PRIVATE_KEY_LOCATION)
return jwt.encode(token_data, private_key, 'RS256')
return jwt.encode(token_data, instance_keys.local_private_key, 'RS256')
def _infohash(infodict):
digest = hashlib.sha1()

View file

@ -8,7 +8,8 @@ from data.database import CloseForLongOperation
from data import model
from data.model.storage import get_storage_locations
from util.secscan.validator import SecurityConfigValidator
from util.security.registry_jwt import generate_jwt_object, build_context_and_subject
from util.security.instancekeys import InstanceKeys
from util.security.registry_jwt import generate_bearer_token, build_context_and_subject
from util import get_app_url
@ -43,6 +44,7 @@ class SecurityScannerAPI(object):
self._app = app
self._config = config
self._instance_keys = InstanceKeys(app)
self._client = client or config['HTTPCLIENT']
self._storage = storage
self._default_storage_locations = config['DISTRIBUTED_STORAGE_PREFERENCE']
@ -80,9 +82,10 @@ class SecurityScannerAPI(object):
'name': repository_and_namespace,
'actions': ['pull'],
}]
auth_jwt = generate_jwt_object(audience, subject, context, access, TOKEN_VALIDITY_LIFETIME_S,
self._config)
auth_header = 'Bearer {}'.format(auth_jwt)
auth_token = generate_bearer_token(audience, subject, context, access,
TOKEN_VALIDITY_LIFETIME_S, self._instance_keys)
auth_header = 'Bearer ' + auth_token
with self._app.test_request_context('/'):
relative_layer_url = url_for('v2.download_blob', repository=repository_and_namespace,

View file

@ -0,0 +1,81 @@
from cachetools import lru_cache
from data import model
from util.expiresdict import ExpiresDict, ExpiresEntry
from util.security import jwtutil
class InstanceKeys(object):
""" InstanceKeys defines a helper class for interacting with the Quay instance service keys
used for JWT signing of registry tokens as well as requests from Quay to other services
such as Clair. Each container will have a single registered instance key.
"""
def __init__(self, app):
self.app = app
self.instance_keys = ExpiresDict(self._load_instance_keys)
self.public_keys = {}
def clear_cache(self):
""" Clears the cache of instance keys. """
self.instance_keys = ExpiresDict(self._load_instance_keys)
self.public_keys = {}
def _load_instance_keys(self):
# Load all the instance keys.
keys = {}
for key in model.service_keys.list_service_keys(self.service_name):
keys[key.kid] = ExpiresEntry(key, key.expiration_date)
# Remove any expired or deleted keys from the public keys cache.
for key in self.public_keys:
if key not in keys:
self.public_keys.pop(key)
return keys
@property
def service_name(self):
""" Returns the name of the instance key's service (i.e. 'quay'). """
return self.app.config['INSTANCE_SERVICE_KEY_SERVICE']
@property
def service_key_expiration(self):
""" Returns the defined expiration for instance service keys, in minutes. """
return self.app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
@property
@lru_cache(maxsize=1)
def local_key_id(self):
""" Returns the ID of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'])
@property
@lru_cache(maxsize=1)
def local_private_key(self):
""" Returns the private key of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_LOCATION'])
def get_service_key_public_key(self, kid):
""" Returns the public key associated with the given instance service key or None if none. """
# Note: We do the lookup via instance_keys *first* to ensure that if a key has expired, we
# don't use the entry in the public key cache.
service_key = self.instance_keys.get(kid)
if service_key is None:
# Remove the kid from the cache just to be sure.
self.public_keys.pop(kid, None)
return None
public_key = self.public_keys.get(kid)
if public_key is not None:
return public_key
# Convert the JWK into a public key and cache it (since the conversion can take > 200ms).
public_key = jwtutil.jwk_dict_to_public_key(service_key.jwk)
self.public_keys[kid] = public_key
return public_key
def _load_file_contents(path):
""" Returns the contents of the specified file path. """
with open(path) as f:
return f.read()

View file

@ -1,3 +1,5 @@
import re
from datetime import datetime, timedelta
from jwt import PyJWT
from jwt.exceptions import (
@ -5,8 +7,19 @@ from jwt.exceptions import (
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwkest.jwk import keyrep, RSAKey, ECKey
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
class StrictJWT(PyJWT):
""" StrictJWT defines a JWT decoder with extra checks. """
@staticmethod
def _get_default_options():
# Weird syntax to call super on a staticmethod
@ -53,3 +66,16 @@ def exp_max_s_option(max_exp_s):
decode = StrictJWT().decode
def jwk_dict_to_public_key(jwk):
""" Converts the specified JWK into a public key. """
jwkest_key = keyrep(jwk)
if isinstance(jwkest_key, RSAKey):
pycrypto_key = jwkest_key.key
return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
elif isinstance(jwkest_key, ECKey):
x, y = jwkest_key.get_key()
return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())
raise Exception('Unsupported kind of JWK: %s', str(type(jwkest_key)))

View file

@ -1,17 +1,80 @@
import time
import jwt
import logging
from cachetools import lru_cache
from util.security import jwtutil
logger = logging.getLogger(__name__)
ANONYMOUS_SUB = '(anonymous)'
ALGORITHM = 'RS256'
def generate_jwt_object(audience, subject, context, access, lifetime_s, app_config):
""" Generates a compact encoded JWT with the values specified.
class InvalidBearerTokenException(Exception):
pass
def decode_bearer_token(bearer_token, instance_keys):
""" decode_bearer_token decodes the given bearer token that contains both a Key ID as well as the
encoded JWT and returns the decoded and validated JWT. On any error, raises an
InvalidBearerTokenException with the reason for failure.
"""
app_config = instance_keys.app.config
# Extract the jwt token from the header
match = jwtutil.TOKEN_REGEX.match(bearer_token)
if match is None:
raise InvalidBearerTokenException('Invalid bearer token format')
encoded_jwt = match.group(1)
logger.debug('encoded JWT: %s', encoded_jwt)
# Decode the key ID.
headers = jwt.get_unverified_header(encoded_jwt)
kid = headers.get('kid', None)
if kid is None:
logger.error('Missing kid header on encoded JWT: %s', encoded_jwt)
raise InvalidBearerTokenException('Missing kid header')
# Find the matching public key.
public_key = instance_keys.get_service_key_public_key(kid)
if public_key is None:
logger.error('Could not find requested service key %s', kid)
raise InvalidBearerTokenException('Unknown service key')
# Load the JWT returned.
try:
expected_issuer = instance_keys.service_name
audience = app_config['SERVER_HOSTNAME']
max_signed_s = app_config.get('REGISTRY_JWT_AUTH_MAX_FRESH_S', 3660)
max_exp = jwtutil.exp_max_s_option(max_signed_s)
payload = jwtutil.decode(encoded_jwt, public_key, algorithms=[ALGORITHM], audience=audience,
issuer=expected_issuer, options=max_exp)
except jwtutil.InvalidTokenError as ite:
logger.exception('Invalid token reason: %s', ite)
raise InvalidBearerTokenException(ite)
if not 'sub' in payload:
raise InvalidBearerTokenException('Missing sub field in JWT')
return payload
def generate_bearer_token(audience, subject, context, access, lifetime_s, instance_keys):
""" Generates a registry bearer token (without the 'Bearer ' portion) based on the given
information.
"""
return _generate_jwt_object(audience, subject, context, access, lifetime_s,
instance_keys.service_name, instance_keys.local_key_id,
instance_keys.local_private_key)
def _generate_jwt_object(audience, subject, context, access, lifetime_s, issuer, key_id,
private_key):
""" Generates a compact encoded JWT with the values specified. """
token_data = {
'iss': app_config['JWT_AUTH_TOKEN_ISSUER'],
'iss': issuer,
'aud': audience,
'nbf': int(time.time()),
'iat': int(time.time()),
@ -21,15 +84,11 @@ def generate_jwt_object(audience, subject, context, access, lifetime_s, app_conf
'context': context,
}
certificate = _load_certificate_bytes(app_config['JWT_AUTH_CERTIFICATE_PATH'])
token_headers = {
'x5c': [certificate],
'kid': key_id,
}
private_key = _load_private_key(app_config['JWT_AUTH_PRIVATE_KEY_PATH'])
return jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
return jwt.encode(token_data, private_key, ALGORITHM, headers=token_headers)
def build_context_and_subject(user, token, oauthtoken):
@ -64,14 +123,3 @@ def build_context_and_subject(user, token, oauthtoken):
return (context, ANONYMOUS_SUB)
@lru_cache(maxsize=1)
def _load_certificate_bytes(certificate_file_path):
with open(certificate_file_path) as cert_file:
cert_lines = cert_file.readlines()[1:-1]
return ''.join([cert_line.rstrip('\n') for cert_line in cert_lines])
@lru_cache(maxsize=1)
def _load_private_key(private_key_file_path):
with open(private_key_file_path) as private_key_file:
return private_key_file.read()