keyserver: get signer kid from unverified headers

This commit is contained in:
Jimmy Zelinskie 2016-04-11 12:04:42 -04:00 committed by Jimmy Zelinskie
parent 08017c5111
commit 9f4a4092da

View file

@ -2,13 +2,12 @@ import logging
from datetime import datetime from datetime import datetime
import jwt
from flask import Blueprint, jsonify, abort, request, make_response
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from flask import Blueprint, jsonify, abort, request, make_response
from jwkest.jwk import keyrep, RSAKey, ECKey from jwkest.jwk import keyrep, RSAKey, ECKey
from jwt import get_unverified_header
import data.model import data.model
import data.model.service_keys import data.model.service_keys
@ -62,9 +61,8 @@ def _validate_jwt(encoded_jwt, jwk, service):
def _signer_kid(encoded_jwt): def _signer_kid(encoded_jwt):
decoded_jwt = jwt.decode(encoded_jwt, verify=False) headers = get_unverified_header(encoded_jwt)
logger.debug(decoded_jwt) return headers.get('kid', None)
return decoded_jwt.get('kid', None)
def _signer_key(service, signer_kid): def _signer_key(service, signer_kid):
@ -82,7 +80,6 @@ def list_service_keys(service):
@key_server.route('/services/<service>/keys/<kid>', methods=['GET']) @key_server.route('/services/<service>/keys/<kid>', methods=['GET'])
def get_service_key(service, kid): def get_service_key(service, kid):
logger.debug(kid)
try: try:
key = data.model.service_keys.get_service_key(kid) key = data.model.service_keys.get_service_key(kid)
except data.model.ServiceKeyDoesNotExist: except data.model.ServiceKeyDoesNotExist:
@ -116,8 +113,6 @@ def put_service_key(service, kid):
logger.exception('Error parsing JWK') logger.exception('Error parsing JWK')
abort(400) abort(400)
logger.debug(jwk)
jwt_header = request.headers.get(JWT_HEADER_NAME, '') jwt_header = request.headers.get(JWT_HEADER_NAME, '')
match = TOKEN_REGEX.match(jwt_header) match = TOKEN_REGEX.match(jwt_header)
if match is None: if match is None: