Use the instance service key for registry JWT signing

This commit is contained in:
Joseph Schorr 2016-05-31 16:48:19 -04:00
parent a4aa5cc02a
commit 8887f09ba8
26 changed files with 457 additions and 278 deletions

View file

@ -0,0 +1,81 @@
from cachetools import lru_cache
from data import model
from util.expiresdict import ExpiresDict, ExpiresEntry
from util.security import jwtutil
class InstanceKeys(object):
""" InstanceKeys defines a helper class for interacting with the Quay instance service keys
used for JWT signing of registry tokens as well as requests from Quay to other services
such as Clair. Each container will have a single registered instance key.
"""
def __init__(self, app):
self.app = app
self.instance_keys = ExpiresDict(self._load_instance_keys)
self.public_keys = {}
def clear_cache(self):
""" Clears the cache of instance keys. """
self.instance_keys = ExpiresDict(self._load_instance_keys)
self.public_keys = {}
def _load_instance_keys(self):
# Load all the instance keys.
keys = {}
for key in model.service_keys.list_service_keys(self.service_name):
keys[key.kid] = ExpiresEntry(key, key.expiration_date)
# Remove any expired or deleted keys from the public keys cache.
for key in self.public_keys:
if key not in keys:
self.public_keys.pop(key)
return keys
@property
def service_name(self):
""" Returns the name of the instance key's service (i.e. 'quay'). """
return self.app.config['INSTANCE_SERVICE_KEY_SERVICE']
@property
def service_key_expiration(self):
""" Returns the defined expiration for instance service keys, in minutes. """
return self.app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)
@property
@lru_cache(maxsize=1)
def local_key_id(self):
""" Returns the ID of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_KID_LOCATION'])
@property
@lru_cache(maxsize=1)
def local_private_key(self):
""" Returns the private key of the local instance service key. """
return _load_file_contents(self.app.config['INSTANCE_SERVICE_KEY_LOCATION'])
def get_service_key_public_key(self, kid):
""" Returns the public key associated with the given instance service key or None if none. """
# Note: We do the lookup via instance_keys *first* to ensure that if a key has expired, we
# don't use the entry in the public key cache.
service_key = self.instance_keys.get(kid)
if service_key is None:
# Remove the kid from the cache just to be sure.
self.public_keys.pop(kid, None)
return None
public_key = self.public_keys.get(kid)
if public_key is not None:
return public_key
# Convert the JWK into a public key and cache it (since the conversion can take > 200ms).
public_key = jwtutil.jwk_dict_to_public_key(service_key.jwk)
self.public_keys[kid] = public_key
return public_key
def _load_file_contents(path):
""" Returns the contents of the specified file path. """
with open(path) as f:
return f.read()

View file

@ -1,3 +1,5 @@
import re
from datetime import datetime, timedelta
from jwt import PyJWT
from jwt.exceptions import (
@ -5,8 +7,19 @@ from jwt.exceptions import (
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
)
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwkest.jwk import keyrep, RSAKey, ECKey
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
class StrictJWT(PyJWT):
""" StrictJWT defines a JWT decoder with extra checks. """
@staticmethod
def _get_default_options():
# Weird syntax to call super on a staticmethod
@ -53,3 +66,16 @@ def exp_max_s_option(max_exp_s):
decode = StrictJWT().decode
def jwk_dict_to_public_key(jwk):
""" Converts the specified JWK into a public key. """
jwkest_key = keyrep(jwk)
if isinstance(jwkest_key, RSAKey):
pycrypto_key = jwkest_key.key
return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
elif isinstance(jwkest_key, ECKey):
x, y = jwkest_key.get_key()
return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())
raise Exception('Unsupported kind of JWK: %s', str(type(jwkest_key)))

View file

@ -1,17 +1,80 @@
import time
import jwt
import logging
from cachetools import lru_cache
from util.security import jwtutil
logger = logging.getLogger(__name__)
ANONYMOUS_SUB = '(anonymous)'
ALGORITHM = 'RS256'
def generate_jwt_object(audience, subject, context, access, lifetime_s, app_config):
""" Generates a compact encoded JWT with the values specified.
class InvalidBearerTokenException(Exception):
pass
def decode_bearer_token(bearer_token, instance_keys):
""" decode_bearer_token decodes the given bearer token that contains both a Key ID as well as the
encoded JWT and returns the decoded and validated JWT. On any error, raises an
InvalidBearerTokenException with the reason for failure.
"""
app_config = instance_keys.app.config
# Extract the jwt token from the header
match = jwtutil.TOKEN_REGEX.match(bearer_token)
if match is None:
raise InvalidBearerTokenException('Invalid bearer token format')
encoded_jwt = match.group(1)
logger.debug('encoded JWT: %s', encoded_jwt)
# Decode the key ID.
headers = jwt.get_unverified_header(encoded_jwt)
kid = headers.get('kid', None)
if kid is None:
logger.error('Missing kid header on encoded JWT: %s', encoded_jwt)
raise InvalidBearerTokenException('Missing kid header')
# Find the matching public key.
public_key = instance_keys.get_service_key_public_key(kid)
if public_key is None:
logger.error('Could not find requested service key %s', kid)
raise InvalidBearerTokenException('Unknown service key')
# Load the JWT returned.
try:
expected_issuer = instance_keys.service_name
audience = app_config['SERVER_HOSTNAME']
max_signed_s = app_config.get('REGISTRY_JWT_AUTH_MAX_FRESH_S', 3660)
max_exp = jwtutil.exp_max_s_option(max_signed_s)
payload = jwtutil.decode(encoded_jwt, public_key, algorithms=[ALGORITHM], audience=audience,
issuer=expected_issuer, options=max_exp)
except jwtutil.InvalidTokenError as ite:
logger.exception('Invalid token reason: %s', ite)
raise InvalidBearerTokenException(ite)
if not 'sub' in payload:
raise InvalidBearerTokenException('Missing sub field in JWT')
return payload
def generate_bearer_token(audience, subject, context, access, lifetime_s, instance_keys):
""" Generates a registry bearer token (without the 'Bearer ' portion) based on the given
information.
"""
return _generate_jwt_object(audience, subject, context, access, lifetime_s,
instance_keys.service_name, instance_keys.local_key_id,
instance_keys.local_private_key)
def _generate_jwt_object(audience, subject, context, access, lifetime_s, issuer, key_id,
private_key):
""" Generates a compact encoded JWT with the values specified. """
token_data = {
'iss': app_config['JWT_AUTH_TOKEN_ISSUER'],
'iss': issuer,
'aud': audience,
'nbf': int(time.time()),
'iat': int(time.time()),
@ -21,15 +84,11 @@ def generate_jwt_object(audience, subject, context, access, lifetime_s, app_conf
'context': context,
}
certificate = _load_certificate_bytes(app_config['JWT_AUTH_CERTIFICATE_PATH'])
token_headers = {
'x5c': [certificate],
'kid': key_id,
}
private_key = _load_private_key(app_config['JWT_AUTH_PRIVATE_KEY_PATH'])
return jwt.encode(token_data, private_key, 'RS256', headers=token_headers)
return jwt.encode(token_data, private_key, ALGORITHM, headers=token_headers)
def build_context_and_subject(user, token, oauthtoken):
@ -64,14 +123,3 @@ def build_context_and_subject(user, token, oauthtoken):
return (context, ANONYMOUS_SUB)
@lru_cache(maxsize=1)
def _load_certificate_bytes(certificate_file_path):
with open(certificate_file_path) as cert_file:
cert_lines = cert_file.readlines()[1:-1]
return ''.join([cert_line.rstrip('\n') for cert_line in cert_lines])
@lru_cache(maxsize=1)
def _load_private_key(private_key_file_path):
with open(private_key_file_path) as private_key_file:
return private_key_file.read()