Implement a basic test suite for jwtutil and add extra checks to the decode method
This commit is contained in:
parent
436e8cb760
commit
4868f17832
2 changed files with 149 additions and 9 deletions
|
@ -5,7 +5,8 @@ from datetime import datetime, timedelta
|
||||||
from jwt import PyJWT
|
from jwt import PyJWT
|
||||||
from jwt.exceptions import (
|
from jwt.exceptions import (
|
||||||
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
||||||
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
|
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
|
||||||
|
InvalidAlgorithmError
|
||||||
)
|
)
|
||||||
|
|
||||||
from cryptography.hazmat.backends import default_backend
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
@ -17,14 +18,19 @@ from jwkest.jwk import keyrep, RSAKey, ECKey
|
||||||
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
|
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
|
||||||
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
|
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
|
||||||
|
|
||||||
|
# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD
|
||||||
|
# `none` here!
|
||||||
|
ALGORITHM_WHITELIST = [
|
||||||
|
'rs256'
|
||||||
|
]
|
||||||
|
|
||||||
class StrictJWT(PyJWT):
|
class _StrictJWT(PyJWT):
|
||||||
""" StrictJWT defines a JWT decoder with extra checks. """
|
""" _StrictJWT defines a JWT decoder with extra checks. """
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_default_options():
|
def _get_default_options():
|
||||||
# Weird syntax to call super on a staticmethod
|
# Weird syntax to call super on a staticmethod
|
||||||
defaults = super(StrictJWT, StrictJWT)._get_default_options()
|
defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
|
||||||
defaults.update({
|
defaults.update({
|
||||||
'require_exp': True,
|
'require_exp': True,
|
||||||
'require_iat': True,
|
'require_iat': True,
|
||||||
|
@ -41,14 +47,14 @@ class StrictJWT(PyJWT):
|
||||||
options['verify_exp'] = True
|
options['verify_exp'] = True
|
||||||
|
|
||||||
# Do all of the other checks
|
# Do all of the other checks
|
||||||
super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
||||||
|
|
||||||
now = timegm(datetime.utcnow().utctimetuple())
|
now = timegm(datetime.utcnow().utctimetuple())
|
||||||
self._reject_future_iat(payload, now, leeway)
|
self._reject_future_iat(payload, now, leeway)
|
||||||
|
|
||||||
if 'exp' in payload and options.get('exp_max_s') is not None:
|
if 'exp' in payload and options.get('exp_max_s') is not None:
|
||||||
# Validate that the expiration was not more than exp_max_s seconds after the issue time
|
# Validate that the expiration was not more than exp_max_s seconds after the issue time
|
||||||
# or in the absense of an issue time, more than exp_max_s in the future from now
|
# or in the absence of an issue time, more than exp_max_s in the future from now
|
||||||
|
|
||||||
# This will work because the parent method already checked the type of exp
|
# This will work because the parent method already checked the type of exp
|
||||||
expiration = datetime.utcfromtimestamp(int(payload['exp']))
|
expiration = datetime.utcfromtimestamp(int(payload['exp']))
|
||||||
|
@ -73,15 +79,30 @@ class StrictJWT(PyJWT):
|
||||||
' the future.')
|
' the future.')
|
||||||
|
|
||||||
|
|
||||||
|
def decode(jwt, key='', verify=True, algorithms=None, options=None,
|
||||||
|
**kwargs):
|
||||||
|
""" Decodes a JWT. """
|
||||||
|
if not algorithms:
|
||||||
|
raise InvalidAlgorithmError('algorithms must be specified')
|
||||||
|
|
||||||
|
normalized = set([a.lower() for a in algorithms])
|
||||||
|
if 'none' in normalized:
|
||||||
|
raise InvalidAlgorithmError('`none` algorithm is not allowed')
|
||||||
|
|
||||||
|
if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized):
|
||||||
|
raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
|
||||||
|
(algorithms, ALGORITHM_WHITELIST))
|
||||||
|
|
||||||
|
return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def exp_max_s_option(max_exp_s):
|
def exp_max_s_option(max_exp_s):
|
||||||
|
""" Returns an options dictionary that sets the maximum expiration seconds for a JWT. """
|
||||||
return {
|
return {
|
||||||
'exp_max_s': max_exp_s,
|
'exp_max_s': max_exp_s,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
decode = StrictJWT().decode
|
|
||||||
|
|
||||||
|
|
||||||
def jwk_dict_to_public_key(jwk):
|
def jwk_dict_to_public_key(jwk):
|
||||||
""" Converts the specified JWK into a public key. """
|
""" Converts the specified JWK into a public key. """
|
||||||
jwkest_key = keyrep(jwk)
|
jwkest_key = keyrep(jwk)
|
||||||
|
|
119
util/security/test/test_jwtutil.py
Normal file
119
util/security/test/test_jwtutil.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from Crypto.PublicKey import RSA
|
||||||
|
from jwkest.jwk import RSAKey
|
||||||
|
|
||||||
|
from util.security.jwtutil import (decode, exp_max_s_option, jwk_dict_to_public_key,
|
||||||
|
InvalidTokenError, InvalidAlgorithmError)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def private_key():
|
||||||
|
return RSA.generate(2048)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def private_key_pem(private_key):
|
||||||
|
return private_key.exportKey('PEM')
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def public_key(private_key):
|
||||||
|
return private_key.publickey().exportKey('PEM')
|
||||||
|
|
||||||
|
|
||||||
|
def _token_data(audience, subject, iss, iat=None, exp=None, nbf=None):
|
||||||
|
return {
|
||||||
|
'iss': iss,
|
||||||
|
'aud': audience,
|
||||||
|
'nbf': nbf() if nbf is not None else int(time.time()),
|
||||||
|
'iat': iat() if iat is not None else int(time.time()),
|
||||||
|
'exp': exp() if exp is not None else int(time.time() + 3600),
|
||||||
|
'sub': subject,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('aud, iss, nbf, iat, exp, expected_exception', [
|
||||||
|
pytest.param('invalidaudience', 'someissuer', None, None, None, 'Invalid audience',
|
||||||
|
id='invalid audience'),
|
||||||
|
pytest.param('someaudience', 'invalidissuer', None, None, None, 'Invalid issuer',
|
||||||
|
id='invalid issuer'),
|
||||||
|
pytest.param('someaudience', 'someissuer', lambda: time.time() + 120, None, None,
|
||||||
|
'The token is not yet valid',
|
||||||
|
id='invalid not before'),
|
||||||
|
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 120, None,
|
||||||
|
'Issued At claim',
|
||||||
|
id='issued at in future'),
|
||||||
|
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 100,
|
||||||
|
'Signature has expired',
|
||||||
|
id='already expired'),
|
||||||
|
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() + 10000,
|
||||||
|
'Token was signed for more than',
|
||||||
|
id='expiration too far in future'),
|
||||||
|
|
||||||
|
pytest.param('someaudience', 'someissuer', lambda: time.time() + 10, None, None,
|
||||||
|
None,
|
||||||
|
id='not before in future by within leeway'),
|
||||||
|
pytest.param('someaudience', 'someissuer', None, lambda: time.time() + 10, None,
|
||||||
|
None,
|
||||||
|
id='issued at in future but within leeway'),
|
||||||
|
pytest.param('someaudience', 'someissuer', None, None, lambda: time.time() - 10,
|
||||||
|
None,
|
||||||
|
id='expiration in past but within leeway'),
|
||||||
|
])
|
||||||
|
def test_decode_jwt_validation(aud, iss, nbf, iat, exp, expected_exception, private_key_pem,
|
||||||
|
public_key):
|
||||||
|
token = jwt.encode(_token_data(aud, 'subject', iss, iat, exp, nbf), private_key_pem, 'RS256')
|
||||||
|
|
||||||
|
if expected_exception is not None:
|
||||||
|
with pytest.raises(InvalidTokenError) as ite:
|
||||||
|
max_exp = exp_max_s_option(3600)
|
||||||
|
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
|
||||||
|
issuer='someissuer', options=max_exp, leeway=60)
|
||||||
|
assert ite.match(expected_exception)
|
||||||
|
else:
|
||||||
|
max_exp = exp_max_s_option(3600)
|
||||||
|
decode(token, public_key, algorithms=['RS256'], audience='someaudience',
|
||||||
|
issuer='someissuer', options=max_exp, leeway=60)
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_jwt_invalid_key(private_key_pem):
|
||||||
|
# Encode with the test private key.
|
||||||
|
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||||
|
|
||||||
|
# Try to decode with a different public key.
|
||||||
|
another_public_key = RSA.generate(2048).publickey().exportKey('PEM')
|
||||||
|
with pytest.raises(InvalidTokenError) as ite:
|
||||||
|
max_exp = exp_max_s_option(3600)
|
||||||
|
decode(token, another_public_key, algorithms=['RS256'], audience='aud',
|
||||||
|
issuer='someissuer', options=max_exp, leeway=60)
|
||||||
|
assert ite.match('Signature verification failed')
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_jwt_invalid_algorithm(private_key_pem, public_key):
|
||||||
|
# Encode with the test private key.
|
||||||
|
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||||
|
|
||||||
|
# Attempt to decode but only with a different algorithm than that used.
|
||||||
|
with pytest.raises(InvalidAlgorithmError) as ite:
|
||||||
|
max_exp = exp_max_s_option(3600)
|
||||||
|
decode(token, public_key, algorithms=['ES256'], audience='aud',
|
||||||
|
issuer='someissuer', options=max_exp, leeway=60)
|
||||||
|
assert ite.match('are not whitelisted')
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwk_dict_to_public_key(private_key, private_key_pem):
|
||||||
|
public_key = private_key.publickey()
|
||||||
|
jwk = RSAKey(key=private_key.publickey()).serialize()
|
||||||
|
converted = jwk_dict_to_public_key(jwk)
|
||||||
|
|
||||||
|
# Encode with the test private key.
|
||||||
|
token = jwt.encode(_token_data('aud', 'subject', 'someissuer'), private_key_pem, 'RS256')
|
||||||
|
|
||||||
|
# Decode with the converted key.
|
||||||
|
max_exp = exp_max_s_option(3600)
|
||||||
|
decode(token, converted, algorithms=['RS256'], audience='aud',
|
||||||
|
issuer='someissuer', options=max_exp, leeway=60)
|
Reference in a new issue