import logging

from datetime import datetime, timedelta

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from flask import Blueprint, jsonify, abort, request, make_response
from jwkest.jwk import keyrep, RSAKey, ECKey
from jwt import get_unverified_header, InvalidTokenError

import data.model
import data.model.service_keys

from data.model.log import log_action

from app import app
from auth.registry_jwt_auth import TOKEN_REGEX
from util.security import strictjwt


logger = logging.getLogger(__name__)
key_server = Blueprint('key_server', __name__)

# HTTP header carrying the bearer JWT that a service signs with one of its keys.
JWT_HEADER_NAME = 'Authorization'

# Required `aud` claim of incoming JWTs: this server's scheme + hostname from the app config.
JWT_AUDIENCE = app.config['PREFERRED_URL_SCHEME'] + '://' + app.config['SERVER_HOSTNAME']


def _validate_jwk(jwk):
    """Abort with HTTP 400 unless *jwk* is a dict shaped like a supported public JWK.

    Only the members needed to build a public key are checked: ``x``/``y`` for
    ``kty == 'EC'`` and ``e``/``n`` for ``kty == 'RSA'``. Any other ``kty`` is rejected.
    """
    # request.get_json() may yield None, a list, a string or a number. Reject anything that
    # is not a JSON object here instead of failing below with a TypeError (HTTP 500).
    if not isinstance(jwk, dict) or 'kty' not in jwk:
        abort(400)

    if jwk['kty'] == 'EC':
        if 'x' not in jwk or 'y' not in jwk:
            abort(400)
    elif jwk['kty'] == 'RSA':
        if 'e' not in jwk or 'n' not in jwk:
            abort(400)
    else:
        abort(400)


def _jwk_dict_to_public_key(jwk):
    """Convert a JWK dict into a ``cryptography`` public key object.

    Raises:
        ValueError: if jwkest parses *jwk* into something other than an RSA or EC key.
    """
    jwkest_key = keyrep(jwk)
    if isinstance(jwkest_key, RSAKey):
        pycrypto_key = jwkest_key.key
        return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())
    elif isinstance(jwkest_key, ECKey):
        x, y = jwkest_key.get_key()
        # NOTE(review): jwkest_key.curve is jwkest's curve object; confirm that cryptography
        # accepts it as the `curve` argument of EllipticCurvePublicNumbers.
        return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())

    # Fail loudly here rather than returning None, which would only surface later as an
    # obscure error inside JWT decoding.
    raise ValueError('Unsupported kind of JWK: %s' % type(jwkest_key).__name__)


def _validate_jwt(encoded_jwt, jwk, service):
    """Verify *encoded_jwt* with the public key in *jwk*, aborting with HTTP 400 on failure.

    The token must be RS256-signed, have ``JWT_AUDIENCE`` as audience and *service* as issuer.
    """
    public_key = _jwk_dict_to_public_key(jwk)

    try:
        # NOTE(review): only RS256 is accepted although _validate_jwk also admits EC keys,
        # so JWTs signed by EC keys presumably never validate here -- confirm intent.
        strictjwt.decode(encoded_jwt, public_key, algorithms=['RS256'], audience=JWT_AUDIENCE,
                         issuer=service)
    except strictjwt.InvalidTokenError:
        logger.exception('JWT validation failure')
        abort(400)


def _signer_kid(encoded_jwt):
    """Return the (unverified) ``kid`` header of *encoded_jwt*, or None if it has none.

    Aborts with HTTP 400 if the token header cannot be parsed at all.
    """
    try:
        headers = get_unverified_header(encoded_jwt)
    except InvalidTokenError:
        # A malformed bearer token is a client error, not a server failure.
        logger.exception('Could not decode JWT header')
        abort(400)

    if not isinstance(headers, dict):
        abort(400)

    return headers.get('kid', None)


def _signer_key(service, signer_kid):
    """Look up key *signer_kid* registered for *service*; abort with HTTP 403 if none exists."""
    try:
        return data.model.service_keys.get_service_key(signer_kid, service=service)
    except data.model.ServiceKeyDoesNotExist:
        abort(403)
@key_server.route('/services/<service>/keys', methods=['GET'])
def list_service_keys(service):
    """Return all keys listed for *service* as ``{"keys": [jwk, ...]}``."""
    keys = data.model.service_keys.list_service_keys(service)
    return jsonify({'keys': [key.jwk for key in keys]})


@key_server.route('/services/<service>/keys/<kid>', methods=['GET'])
def get_service_key(service, kid):
    """Return the public JWK for key *kid*.

    Responds 404 if the key does not exist, 409 if it has no approval yet, and 403 if it has
    expired. Otherwise the response is cacheable until expiration, capped at one day.
    """
    # NOTE(review): this lookup is not filtered by *service*, so a key is served under any
    # service path -- confirm that is intended.
    try:
        key = data.model.service_keys.get_service_key(kid)
    except data.model.ServiceKeyDoesNotExist:
        abort(404)

    if key.approval is None:
        abort(409)

    now = datetime.utcnow()
    if key.expiration_date is not None and key.expiration_date <= now:
        abort(403)

    resp = jsonify(key.jwk)
    lifetime = min(timedelta(days=1), (key.expiration_date or datetime.max) - now)
    # Cache-Control max-age is defined as an integer number of seconds (delta-seconds).
    resp.cache_control.max_age = max(0, int(lifetime.total_seconds()))
    return resp


@key_server.route('/services/<service>/keys/<kid>', methods=['PUT'])
def put_service_key(service, kid):
    """Create or rotate the key *kid* of *service*.

    The request body is the new public JWK. The ``Authorization`` header must carry a bearer
    JWT issued by *service*. Optional query args: ``expiration`` (Unix timestamp) and
    ``rotation`` (passed to the model as the raw string).

    If the JWT is self-signed (its ``kid`` header equals *kid* or is absent), a new key is
    created and awaits approval (202). Otherwise the JWT must be signed by an existing key of
    the same service, and that key is rotated into *kid* (200).
    """
    metadata = {'ip': request.remote_addr}

    rotation_duration = request.args.get('rotation', None)
    expiration_date = request.args.get('expiration', None)
    if expiration_date is not None:
        try:
            expiration_date = datetime.utcfromtimestamp(float(expiration_date))
        except (ValueError, OverflowError, OSError):
            # float() raises ValueError on non-numeric input. utcfromtimestamp() raises
            # OverflowError/OSError/ValueError for out-of-range values such as "inf" or "1e20".
            logger.exception('Error parsing expiration date on key')
            abort(400)

    try:
        jwk = request.get_json()
    except ValueError:
        logger.exception('Error parsing JWK')
        abort(400)

    jwt_header = request.headers.get(JWT_HEADER_NAME, '')
    match = TOKEN_REGEX.match(jwt_header)
    if match is None:
        logger.error('Could not find matching bearer token')
        abort(400)

    encoded_jwt = match.group(1)

    _validate_jwk(jwk)

    signer_kid = _signer_kid(encoded_jwt)

    if kid == signer_kid or signer_kid is None:
        # The key is self-signed. Create a new instance and await approval.
        _validate_jwt(encoded_jwt, jwk, service)
        data.model.service_keys.create_service_key('', kid, service, jwk, metadata, expiration_date,
                                                   rotation_duration=rotation_duration)

        key_log_metadata = {
            'kid': kid,
            'preshared': False,
            'service': service,
            'name': '',
            'expiration_date': expiration_date,
            'user_agent': request.headers.get('User-Agent'),
            'ip': request.remote_addr,
        }

        log_action('service_key_create', None, metadata=key_log_metadata, ip=request.remote_addr)
        return make_response('', 202)

    # Otherwise an existing key of this service is rotating itself into *kid*.
    # NOTE(review): the signer's approval status is not checked here -- confirm that
    # replace_service_key carries over (or withholds) approval appropriately.
    metadata.update({'created_by': 'Key Rotation'})
    signer_key = _signer_key(service, signer_kid)
    signer_jwk = signer_key.jwk

    if signer_key.service != service:
        abort(403)

    _validate_jwt(encoded_jwt, signer_jwk, service)

    try:
        data.model.service_keys.replace_service_key(signer_key.kid, kid, jwk, metadata,
                                                    expiration_date)
    except data.model.ServiceKeyDoesNotExist:
        abort(404)

    key_log_metadata = {
        'kid': kid,
        'signer_kid': signer_key.kid,
        'service': service,
        'name': signer_key.name,
        'expiration_date': expiration_date,
        'user_agent': request.headers.get('User-Agent'),
        'ip': request.remote_addr,
    }

    log_action('service_key_rotate', None, metadata=key_log_metadata, ip=request.remote_addr)
    return make_response('', 200)


@key_server.route('/services/<service>/keys/<kid>', methods=['DELETE'])
def delete_service_key(service, kid):
    """Delete the key *kid* of *service*.

    The request is allowed when the bearer JWT is signed by *kid* itself, or by an approved key
    of the same service. Responses:

    - 400: the JWT is missing or invalid.
    - 403: the signer is unknown or not allowed.
    - 404: *kid* does not exist for *service*.
    - 204: success.
    """
    jwt_header = request.headers.get(JWT_HEADER_NAME, '')
    match = TOKEN_REGEX.match(jwt_header)
    if match is None:
        abort(400)

    encoded_jwt = match.group(1)

    signer_kid = _signer_kid(encoded_jwt)
    signer_key = _signer_key(service, signer_kid)

    # Self-signed only when the signer IS the key being deleted. An empty signer kid does not
    # identify the target, so it must not grant self-deletion rights over arbitrary kids.
    self_signed = kid == signer_kid
    approved_key_for_service = signer_key.approval is not None

    if not (self_signed or approved_key_for_service):
        abort(403)

    _validate_jwt(encoded_jwt, signer_key.jwk, service)

    if not self_signed:
        # delete_service_key() below is called with the kid only. Confirm the target belongs
        # to this service, so an approved key of one service cannot delete another's keys.
        # Like _signer_key, this assumes the service-filtered lookup also returns unapproved keys.
        try:
            data.model.service_keys.get_service_key(kid, service=service)
        except data.model.ServiceKeyDoesNotExist:
            abort(404)

    try:
        data.model.service_keys.delete_service_key(kid)
    except data.model.ServiceKeyDoesNotExist:
        abort(404)

    key_log_metadata = {
        'kid': kid,
        'signer_kid': signer_key.kid,
        'service': service,
        'name': signer_key.name,
        'user_agent': request.headers.get('User-Agent'),
        'ip': request.remote_addr,
    }

    log_action('service_key_delete', None, metadata=key_log_metadata, ip=request.remote_addr)
    return make_response('', 204)