81 lines
2.9 KiB
Python
81 lines
2.9 KiB
Python
import re
|
|
|
|
from datetime import datetime, timedelta
|
|
from jwt import PyJWT
|
|
from jwt.exceptions import (
|
|
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
|
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
|
|
)
|
|
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
|
from jwkest.jwk import keyrep, RSAKey, ECKey
|
|
|
|
|
|
# TOKEN_REGEX matches an entire Authorization header value of the form
# "Bearer <token>", where <token> is two or more '.'-separated segments made of
# letters, digits and the characters '+', '-', '_' and '/' (JWT-shaped).
# Group 1 captures the token without the "Bearer " prefix.
TOKEN_REGEX = re.compile(
    r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z'
)
|
|
|
|
|
|
class StrictJWT(PyJWT):
    """ StrictJWT defines a JWT decoder with extra checks.

    Compared to the parent PyJWT defaults, it:
      * sets require_exp, require_iat and require_nbf to True, and
      * adds an 'exp_max_s' option (default None = disabled). When set, a token
        is rejected if its 'exp' lies more than exp_max_s seconds after its
        'iat' claim (or after the current UTC time when there is no 'iat').
    """

    @staticmethod
    def _get_default_options():
        # Weird syntax to call super on a staticmethod: there is no self/cls to
        # bind, so the class is named explicitly for both super() arguments.
        defaults = super(StrictJWT, StrictJWT)._get_default_options()
        defaults.update({
            'require_exp': True,
            'require_iat': True,
            'require_nbf': True,
            'exp_max_s': None,
        })
        return defaults

    def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
        """ Runs the parent claim validation, then the optional exp_max_s check.

        Args:
          payload: the decoded claims dict.
          options: the options dict; 'verify_exp' is forced to True when
            'exp_max_s' is set.
          audience, issuer, leeway, kwargs: forwarded unchanged to the parent.

        Raises:
          ValueError: if exp_max_s is set while verify_expiration was
            explicitly passed as falsy.
          InvalidTokenError: if the token's lifetime exceeds exp_max_s, or
            (from the parent) if any other claim check fails.
        """
        max_signed_s = options.get('exp_max_s')
        if max_signed_s is not None:
            if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
                raise ValueError('exp_max_s option implies verify_expiration')

            # Force expiration verification so the parent validates 'exp'
            # before the bound check below relies on it.
            options['verify_exp'] = True

        # Do all of the other checks
        super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)

        if 'exp' in payload and max_signed_s is not None:
            self._validate_exp_max_s(payload, max_signed_s)

    def _validate_exp_max_s(self, payload, max_signed_s):
        """ Rejects the token if 'exp' is more than max_signed_s seconds after its start.

        The start time is the 'iat' claim when present, otherwise the current
        UTC time. Malformed or out-of-range 'exp'/'iat' values (which make the
        datetime conversions raise) are reported as InvalidTokenError, so
        callers that catch InvalidTokenError do not see raw ValueError,
        OverflowError or OSError from untrusted claims.
        """
        # Computed outside the try: a bad exp_max_s is a configuration error,
        # not a problem with the token, so it should not be masked.
        max_delta = timedelta(seconds=max_signed_s)

        try:
            expiration = datetime.utcfromtimestamp(int(payload['exp']))
            if 'iat' in payload:
                start_time = datetime.utcfromtimestamp(int(payload['iat']))
            else:
                start_time = datetime.utcnow()
            latest_allowed = start_time + max_delta
        except (TypeError, ValueError, OverflowError, OSError):
            raise InvalidTokenError('Token has a malformed or out-of-range exp or iat claim')

        if expiration > latest_allowed:
            # Interpolate the message; exception constructors do not apply
            # logging-style '%s' arguments.
            raise InvalidTokenError('Token was signed for more than %s seconds from %s'
                                    % (max_signed_s, start_time))
|
|
|
|
|
|
def exp_max_s_option(max_exp_s):
    """ Builds a decode options dict carrying the 'exp_max_s' setting.

    Args:
      max_exp_s: maximum allowed token lifetime in seconds, or None.

    Returns:
      A new dict {'exp_max_s': max_exp_s} on every call.
    """
    return dict(exp_max_s=max_exp_s)
|
|
|
|
|
|
# Module-level convenience alias: `decode(...)` is the bound `decode` method of
# one shared StrictJWT instance, so callers get StrictJWT's claim validation
# without constructing a decoder themselves.
# NOTE(review): presumably StrictJWT's default options (require exp/iat/nbf)
# are picked up through PyJWT's own option handling at construction — confirm
# against the installed PyJWT version.
decode = StrictJWT().decode
|
|
|
|
|
|
def jwk_dict_to_public_key(jwk):
    """ Converts the specified JWK into a public key.

    Args:
      jwk: the JWK as a dict (passed to jwkest's keyrep); for EC keys its
        'crv' entry selects the curve.

    Returns:
      A cryptography RSA or elliptic-curve public key.

    Raises:
      ValueError: if the JWK kind, or the EC curve, is not supported
        (a subclass of Exception, so existing `except Exception` handlers
        still apply).
    """
    # Local import: only the EC branch needs the concrete curve classes.
    from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1, SECP384R1, SECP521R1

    jwkest_key = keyrep(jwk)

    if isinstance(jwkest_key, RSAKey):
        pycrypto_key = jwkest_key.key
        return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())

    if isinstance(jwkest_key, ECKey):
        # EllipticCurvePublicNumbers requires a cryptography EllipticCurve
        # instance. NOTE(review): jwkest's `curve` attribute appears to be
        # jwkest's own curve object, so the JWK 'crv' name is mapped to the
        # matching cryptography curve instead — confirm against installed jwkest.
        curve_name = jwk.get('crv')
        curve_class = {
            'P-256': SECP256R1,
            'P-384': SECP384R1,
            'P-521': SECP521R1,
        }.get(curve_name)
        if curve_class is None:
            raise ValueError('Unsupported JWK elliptic curve: %s' % (curve_name,))

        x, y = jwkest_key.get_key()
        return EllipticCurvePublicNumbers(x, y, curve_class()).public_key(default_backend())

    # Interpolate the message; exception constructors do not apply
    # logging-style '%s' arguments.
    raise ValueError('Unsupported kind of JWK: %s' % str(type(jwkest_key)))
|