116 lines
4.2 KiB
Python
116 lines
4.2 KiB
Python
import re
|
|
|
|
from calendar import timegm
|
|
from datetime import datetime, timedelta
|
|
from jwt import PyJWT
|
|
from jwt.exceptions import (
|
|
InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
|
|
ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError,
|
|
InvalidAlgorithmError
|
|
)
|
|
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
|
|
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
|
|
from jwkest.jwk import keyrep, RSAKey, ECKey
|
|
|
|
|
|
# TOKEN_REGEX recognizes a whole string of the form `Bearer <token>` (presumably the value of an
# Authorization header), where <token> is two or more dot-separated segments drawn from
# [a-zA-Z0-9+-_/]. Capture group 1 holds the token itself, without the `Bearer ` prefix.
TOKEN_REGEX = re.compile(
    r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z'
)
|
|
|
|
# ALGORITHM_WHITELIST lists the only JWT signing algorithms `decode` will accept, written in
# lowercase because `decode` lowercases the requested algorithms before comparing.
# DO NOT ADD `none` here!
ALGORITHM_WHITELIST = ['rs256']
|
|
|
|
class _StrictJWT(PyJWT):
    """ _StrictJWT defines a JWT decoder with extra checks.

    Compared to the PyJWT base class it:
      * turns on the `require_exp`, `require_iat` and `require_nbf` options by default;
      * always rejects tokens whose `iat` claim lies in the future (beyond `leeway`);
      * supports an `exp_max_s` option that rejects tokens whose `exp` is more than `exp_max_s`
        seconds after their `iat` (or after now, if there is no `iat`).
    """

    @staticmethod
    def _get_default_options():
        """ Returns the parent's default options extended with this class's stricter defaults. """
        # Weird syntax to call super on a staticmethod
        defaults = super(_StrictJWT, _StrictJWT)._get_default_options()
        defaults.update({
            'require_exp': True,
            'require_iat': True,
            'require_nbf': True,
            'exp_max_s': None,
        })
        return defaults

    def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
        """ Validates the claims of a decoded `payload`, adding this class's checks.

        Args:
          payload: the decoded claims dictionary.
          options: merged decode options. Note that `verify_exp` is forced on in this dict when
            `exp_max_s` is set.
          audience, issuer, **kwargs: forwarded unchanged to the parent implementation.
          leeway: allowed clock skew, as a number of seconds or a `datetime.timedelta`.

        Raises:
          ValueError: if `exp_max_s` is set but `verify_expiration` was explicitly disabled.
          InvalidTokenError (or a subclass): if any claim check fails.
        """
        max_signed_s = options.get('exp_max_s')
        if max_signed_s is not None:
            if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
                raise ValueError('exp_max_s option implies verify_expiration')

            # The maximum-lifetime check below relies on `exp` having been validated.
            options['verify_exp'] = True

        # Do all of the other checks
        super(_StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway,
                                                 **kwargs)

        # `leeway` may be given as a timedelta (the parent accepts that form). Normalize it to
        # seconds so the integer arithmetic in `_reject_future_iat` works for both forms.
        if isinstance(leeway, timedelta):
            leeway = leeway.total_seconds()

        now = timegm(datetime.utcnow().utctimetuple())
        self._reject_future_iat(payload, now, leeway)

        if 'exp' in payload and max_signed_s is not None:
            # Validate that the expiration was not more than exp_max_s seconds after the issue
            # time, or in the absence of an issue time, more than exp_max_s in the future from now.
            # This will work because the parent method already checked the type of exp
            # (verify_exp was forced on above).
            expiration = datetime.utcfromtimestamp(int(payload['exp']))

            start_time = datetime.utcnow()
            if 'iat' in payload:
                start_time = datetime.utcfromtimestamp(int(payload['iat']))

            if expiration > start_time + timedelta(seconds=max_signed_s):
                raise InvalidTokenError('Token was signed for more than %s seconds from %s' %
                                        (max_signed_s, start_time))

    def _reject_future_iat(self, payload, now, leeway):
        """ Raises if the `iat` claim of `payload` is later than `now + leeway`.

        A missing `iat` is not an error here. Whether the claim is mandatory is governed by the
        `require_iat` option, which the parent enforces.

        Args:
          payload: the decoded claims dictionary.
          now: current time as seconds since the epoch (UTC).
          leeway: allowed clock skew in seconds.

        Raises:
          DecodeError: if `iat` cannot be converted to an integer.
          InvalidIssuedAtError: if `iat` is in the future beyond `leeway`.
        """
        if 'iat' not in payload:
            return

        try:
            iat = int(payload['iat'])
        except (TypeError, ValueError):
            # TypeError covers non-numeric JSON values such as null, lists or objects.
            raise DecodeError('Issued At claim (iat) must be an integer.')

        if iat > (now + leeway):
            raise InvalidIssuedAtError('Issued At claim (iat) cannot be in'
                                       ' the future.')
|
|
|
|
|
|
def decode(jwt, key='', verify=True, algorithms=None, options=None, **kwargs):
    """ Decodes a JWT using the strict decoder, after vetting the requested algorithms.

    `algorithms` must be non-empty. It must not contain `none` (in any case), and every entry,
    compared case-insensitively, must appear in ALGORITHM_WHITELIST. Otherwise
    InvalidAlgorithmError is raised. All arguments are then forwarded to `_StrictJWT().decode`.
    """
    if not algorithms:
        raise InvalidAlgorithmError('algorithms must be specified')

    requested = {name.lower() for name in algorithms}
    if 'none' in requested:
        raise InvalidAlgorithmError('`none` algorithm is not allowed')

    # Any requested algorithm outside the whitelist is fatal.
    not_allowed = requested - set(ALGORITHM_WHITELIST)
    if not_allowed:
        raise InvalidAlgorithmError('Algorithms `%s` are not whitelisted. Allowed: %s' %
                                    (algorithms, ALGORITHM_WHITELIST))

    strict_decoder = _StrictJWT()
    return strict_decoder.decode(jwt, key, verify, algorithms, options, **kwargs)
|
|
|
|
|
|
def exp_max_s_option(max_exp_s):
    """ Builds an options dict whose `exp_max_s` entry is `max_exp_s`.

    The strict decoder reads `exp_max_s` from its options and rejects tokens whose expiration
    lies more than that many seconds after their issue time. A fresh dict is returned on every
    call.
    """
    return dict(exp_max_s=max_exp_s)
|
|
|
|
|
|
def jwk_dict_to_public_key(jwk):
    """ Converts the specified JWK into a public key.

    Args:
      jwk: a JWK in dictionary form, as accepted by `jwkest.jwk.keyrep`.

    Returns:
      A `cryptography` public key object: RSA for RSA JWKs, elliptic-curve for EC JWKs.

    Raises:
      Exception: if the JWK is neither an RSA nor an EC key, or if it is an EC key whose `crv`
        is not one of P-256, P-384 or P-521.
    """
    # Local import so this function does not depend on additions to the module import block.
    from cryptography.hazmat.primitives.asymmetric import ec

    jwkest_key = keyrep(jwk)
    if isinstance(jwkest_key, RSAKey):
        pycrypto_key = jwkest_key.key
        return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())

    if isinstance(jwkest_key, ECKey):
        # pyjwkest's ECKey keeps its own curve object in `.curve`, but `cryptography` requires
        # one of its own EllipticCurve instances. Map the JWK `crv` name (RFC 7518) instead.
        supported_curves = {
            'P-256': ec.SECP256R1,
            'P-384': ec.SECP384R1,
            'P-521': ec.SECP521R1,
        }
        curve_cls = supported_curves.get(jwkest_key.crv)
        if curve_cls is None:
            raise Exception('Unsupported JWK elliptic curve: %s' % jwkest_key.crv)

        x, y = jwkest_key.get_key()
        return EllipticCurvePublicNumbers(x, y, curve_cls()).public_key(default_backend())

    raise Exception('Unsupported kind of JWK: %s' % str(type(jwkest_key)))
|