Implement a basic test suite for jwtutil and add extra checks to the decode method
This commit is contained in:
parent
436e8cb760
commit
4868f17832
2 changed files with 149 additions and 9 deletions
|
@ -5,7 +5,8 @@ from datetime import datetime, timedelta
|
|||
from jwt import PyJWT
|
||||
from jwt.exceptions import (
|
||||
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
||||
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
|
||||
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
|
||||
InvalidAlgorithmError
|
||||
)
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
|
@ -17,14 +18,19 @@ from jwkest.jwk import keyrep, RSAKey, ECKey
|
|||
# TOKEN_REGEX defines a regular expression for matching JWT bearer tokens.
|
||||
TOKEN_REGEX = re.compile(r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z')
|
||||
|
||||
# ALGORITHM_WHITELIST defines a whitelist of allowed algorithms to be used in JWTs. DO NOT ADD
|
||||
# `none` here!
|
||||
ALGORITHM_WHITELIST = [
|
||||
'rs256'
|
||||
]
|
||||
|
||||
class StrictJWT(PyJWT):
|
||||
""" StrictJWT defines a JWT decoder with extra checks. """
|
||||
class _StrictJWT(PyJWT):
|
||||
""" _StrictJWT defines a JWT decoder with extra checks. """
|
||||
|
||||
@staticmethod
|
||||
def _get_default_options():
|
||||
# Weird syntax to call super on a staticmethod
|
||||
defaults = super(StrictJWT, StrictJWT)._get_default_options()
|
||||
defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
|
||||
defaults.update({
|
||||
'require_exp': True,
|
||||
'require_iat': True,
|
||||
|
@ -41,14 +47,14 @@ class StrictJWT(PyJWT):
|
|||
options['verify_exp'] = True
|
||||
|
||||
# Do all of the other checks
|
||||
super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
||||
super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)
|
||||
|
||||
now = timegm(datetime.utcnow().utctimetuple())
|
||||
self._reject_future_iat(payload, now, leeway)
|
||||
|
||||
if 'exp' in payload and options.get('exp_max_s') is not None:
|
||||
# Validate that the expiration was not more than exp_max_s seconds after the issue time
|
||||
# or in the absense of an issue time, more than exp_max_s in the future from now
|
||||
# or in the absence of an issue time, more than exp_max_s in the future from now
|
||||
|
||||
# This will work because the parent method already checked the type of exp
|
||||
expiration = datetime.utcfromtimestamp(int(payload['exp']))
|
||||
|
@ -73,15 +79,30 @@ class StrictJWT(PyJWT):
|
|||
' the future.')
|
||||
|
||||
|
||||
def decode(jwt, key='', verify=True, algorithms=None, options=None,
|
||||
**kwargs):
|
||||
""" Decodes a JWT. """
|
||||
if not algorithms:
|
||||
raise InvalidAlgorithmError('algorithms must be specified')
|
||||
|
||||
normalized = set([a.lower() for a in algorithms])
|
||||
if 'none' in normalized:
|
||||
raise InvalidAlgorithmError('`none` algorithm is not allowed')
|
||||
|
||||
if set(normalized).intersection(set(ALGORITHM_WHITELIST)) != set(normalized):
|
||||
raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
|
||||
(algorithms, ALGORITHM_WHITELIST))
|
||||
|
||||
return _StrictJWT().decode(jwt, key, verify, algorithms, options, **kwargs)
|
||||
|
||||
|
||||
def exp_max_s_option(max_exp_s):
|
||||
""" Returns an options dictionary that sets the maximum expiration seconds for a JWT. """
|
||||
return {
|
||||
'exp_max_s': max_exp_s,
|
||||
}
|
||||
|
||||
|
||||
decode = StrictJWT().decode
|
||||
|
||||
|
||||
def jwk_dict_to_public_key(jwk):
|
||||
""" Converts the specified JWK into a public key. """
|
||||
jwkest_key = keyrep(jwk)
|
||||
|
|
Reference in a new issue