diff --git a/util/security/jwtutil.py b/util/security/jwtutil.py index 7e2bc6cee..be70aa8a1 100644 --- a/util/security/jwtutil.py +++ b/util/security/jwtutil.py @@ -5,7 +5,8 @@ from datetime import datetime, timedelta from jwt import PyJWT from jwt.exceptions import ( InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError, - ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError + ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError, + InvalidAlgorithmError ) from cryptography.hazmat.backends import default_backend @@ -17,14 +18,19 @@ from jwkest.jwk import keyrep, RSAKey, ECKey # TOKEN_REGEX defines a regular expression for matching JWT bearer tokens. TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z') +# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD +# `none` here! +ALGORITHM_WHITELIST = [ + 'rs256' +] -class StrictJWT(PyJWT): - """ StrictJWT defines a JWT decoder with extra checks. """ +class _StrictJWT(PyJWT): + """ _StrictJWT defines a JWT decoder with extra checks. """ @staticmethod def _get_default_options(): # Weird syntax to call super on a staticmethod - defaults = super(StrictJWT, StrictJWT)._get_default_options() + defaults = super(_StrictJWT, _StrictJWT)._get_default_options() defaults.update({ 'require_exp': True, 'require_iat': True, @@ -41,14 +47,14 @@ class StrictJWT(PyJWT): options['verify_exp'] = True # Do all of the other checks - super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs) + super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs) now = timegm(datetime.utcnow().utctimetuple()) self._reject_future_iat(payload, now, leeway) if 'exp' in payload and options.get('exp_max_s') is not None: # Validate that the expiration was not more than exp_max_s seconds after the issue time - # or in the absense of an issue time, more than exp_max_s in the future from now + # or in the absence of an issue time, more than exp_max_s in the future from now # This will work because the parent method already checked the type of exp expiration = datetime.utcfromtimestamp(int(payload['exp'])) @@ -73,15 +79,30 @@ class StrictJWT(PyJWT): ' the future.') +def decode(jwt, key='', verify=True, algorithms=None, options=None, + **kwargs): + """ Decodes a JWT. """ + if not algorithms: + raise InvalidAlgorithmError('algorithms must be specified') + + normalized = set([a.lower() for a in algorithms]) + if 'none' in normalized: + raise InvalidAlgorithmError('`none` algorithm is not allowed') + + if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized): + raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' % + (algorithms, ALGORITHM_WHITELIST)) + + return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs) + + def exp_max_s_option(max_exp_s): + """ Returns an options dictionary that sets the maximum expiration seconds for a JWT. """ return { 'exp_max_s': max_exp_s, } -decode = StrictJWT().decode - - def jwk_dict_to_public_key(jwk): """ Converts the specified JWK into a public key. """ jwkest_key = keyrep(jwk) diff --git a/util/security/test/test_jwtutil.py b/util/security/test/test_jwtutil.py new file mode 100644 index 000000000..f00785cf7 --- /dev/null +++ b/util/security/test/test_jwtutil.py @@ -0,0 +1,119 @@ +import time + +import pytest +import jwt + +from Crypto.PublicKey import RSA +from jwkest.jwk import RSAKey + +from util.security.jwtutil import (decode, exp_max_s_option, jwk_dict_to_public_key, + InvalidTokenError, InvalidAlgorithmError) + + +@pytest.fixture(scope="session") +def private_key(): + return RSA.generate(2048) + + +@pytest.fixture(scope="session") +def private_key_pem(private_key): + return private_key.exportKey('PEM') + + +@pytest.fixture(scope="session") +def public_key(private_key): + return private_key.publickey().exportKey('PEM') + + +def _token_data(audience, subject, iss, iat=None, exp=None, nbf=None): + return { + 'iss': iss, + 'aud': audience, + 'nbf': nbf() if nbf is not None else int(time.time()), + 'iat': iat() if iat is not None else int(time.time()), + 'exp': exp() if exp is not None else int(time.time() + 3600), + 'sub': subject, + } + + +@pytest.mark.parametrize('aud, iss, nbf, iat, exp, expected_exception', [ + pytest.param('invalidaudience', 'someissuer', None, None, None, 'Invalid audience', + id='invalid audience'), + pytest.param('someaudience', 'invalidissuer', None, None, None, 'Invalid issuer', + id='invalid issuer'), + pytest.param('someaudience', 'someissuer', lambda: time.time() + 120, None, None, + 'The token is not yet valid', + id='invalid not before'), + pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 120, None, + 'Issued At claim', + id='issued at in future'), + pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 100, + 'Signature has expired', + id='already expired'), + pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() + 10000, + 'Token was signed for more than', + id='expiration too far in future'), + + pytest.param('someaudience', 'someissuer', lambda: time.time() + 10, None, None, + None, + id='not before in future by within leeway'), + pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 10, None, + None, + id='issued at in future but within leeway'), + pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 10, + None, + id='expiration in past but within leeway'), +]) +def test_decode_jwt_validation(aud, iss, nbf, iat, exp, expected_exception, private_key_pem, + public_key): + token = jwt.encode(_token_data(aud, 'subject', iss, iat, exp, nbf), private_key_pem, 'RS256') + + if expected_exception is not None: + with pytest.raises(InvalidTokenError) as ite: + max_exp = exp_max_s_option(3600) + decode(token, public_key, algorithms=['RS256'], audience='someaudience', + issuer='someissuer', options=max_exp, leeway=60) + assert ite.match(expected_exception) + else: + max_exp = exp_max_s_option(3600) + decode(token, public_key, algorithms=['RS256'], audience='someaudience', + issuer='someissuer', options=max_exp, leeway=60) + + +def test_decode_jwt_invalid_key(private_key_pem): + # Encode with the test private key. + token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256') + + # Try to decode with a different public key. + another_public_key = RSA.generate(2048).publickey().exportKey('PEM') + with pytest.raises(InvalidTokenError) as ite: + max_exp = exp_max_s_option(3600) + decode(token, another_public_key, algorithms=['RS256'], audience='aud', + issuer='someissuer', options=max_exp, leeway=60) + assert ite.match('Signature verification failed') + + +def test_decode_jwt_invalid_algorithm(private_key_pem, public_key): + # Encode with the test private key. + token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256') + + # Attempt to decode but only with a different algorithm than that used. + with pytest.raises(InvalidAlgorithmError) as ite: + max_exp = exp_max_s_option(3600) + decode(token, public_key, algorithms=['ES256'], audience='aud', + issuer='someissuer', options=max_exp, leeway=60) + assert ite.match('are not whitelisted') + + +def test_jwk_dict_to_public_key(private_key, private_key_pem): + public_key = private_key.publickey() + jwk = RSAKey(key=private_key.publickey()).serialize() + converted = jwk_dict_to_public_key(jwk) + + # Encode with the test private key. + token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256') + + # Decode with the converted key. + max_exp = exp_max_s_option(3600) + decode(token, converted, algorithms=['RS256'], audience='aud', + issuer='someissuer', options=max_exp, leeway=60)