"""Stricter JWT decoding built on PyJWT.

``StrictJWT`` turns on the ``require_exp``, ``require_iat`` and ``require_nbf``
options by default and adds an ``exp_max_s`` option that caps how far past its
issue time a token's expiration may lie.
"""
import time
from datetime import datetime, timedelta

from jwt import PyJWT
from jwt.exceptions import (
    InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
    ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError,
    MissingRequiredClaimError
)


class StrictJWT(PyJWT):
    """PyJWT subclass with stricter claim defaults and an ``exp_max_s`` option.

    ``exp_max_s`` (seconds as int/float, or a ``timedelta``; ``None`` disables
    the check) rejects tokens whose ``exp`` is more than that many seconds after
    their ``iat`` -- or, when the token carries no ``iat``, after the current
    time.
    """

    @staticmethod
    def _get_default_options():
        """Return the parent's default options extended with this class's defaults."""
        # super() needs explicit arguments here: a staticmethod has no self/cls.
        defaults = super(StrictJWT, StrictJWT)._get_default_options()
        defaults.update({
            'require_exp': True,
            'require_iat': True,
            'require_nbf': True,
            'exp_max_s': None,
        })
        return defaults

    def _validate_claims(self, payload, options, audience=None, issuer=None,
                         leeway=0, **kwargs):
        """Run the parent's claim validation, then enforce ``exp_max_s`` if set.

        Args:
            payload: decoded claims mapping.
            options: merged option mapping; ``exp_max_s`` is read from it.
            audience, issuer, leeway, **kwargs: forwarded unchanged to the parent.

        Raises:
            ValueError: ``exp_max_s`` is set while ``verify_expiration`` was
                explicitly passed as false.
            InvalidIssuedAtError: ``exp_max_s`` is set and ``iat`` is present but
                not convertible to an integer.
            InvalidTokenError: ``exp`` lies more than ``exp_max_s`` seconds after
                the start time (``iat``, or now when ``iat`` is absent).
            Anything raised by the parent's validation.
        """
        max_signed_s = options.get('exp_max_s')
        if max_signed_s is not None:
            if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
                raise ValueError('exp_max_s option implies verify_expiration')
            # Force the parent's exp check (which validates the type of exp) on a
            # copy, so the caller's options mapping is left untouched.
            options = dict(options, verify_exp=True)

        # Do all of the other checks
        super(StrictJWT, self)._validate_claims(payload, options, audience, issuer,
                                                leeway, **kwargs)

        if max_signed_s is None or 'exp' not in payload:
            return

        if isinstance(max_signed_s, timedelta):
            max_signed_s = max_signed_s.total_seconds()

        # Compare raw Unix timestamps instead of converting to datetime: the
        # claims are caller/attacker supplied, and datetime.utcfromtimestamp
        # raises OverflowError/ValueError/OSError for extreme values.
        # exp is safe to int() because the parent already checked its type.
        expiration = int(payload['exp'])
        if 'iat' in payload:
            try:
                start_time = int(payload['iat'])
            except (TypeError, ValueError):
                # NOTE(review): the parent may not have validated iat (e.g. when
                # iat verification is disabled), so don't assume it is numeric.
                raise InvalidIssuedAtError('Issued At claim (iat) must be an integer.')
        else:
            start_time = time.time()

        if expiration > start_time + max_signed_s:
            raise InvalidTokenError(
                'Token was signed for more than %s seconds from %s'
                % (max_signed_s, start_time))


def exp_max_s_option(max_exp_s):
    """Return an options mapping that contains only ``exp_max_s``.

    Presumably intended to be passed as the ``options`` argument of ``decode``.
    """
    return {'exp_max_s': max_exp_s}


# Module-level convenience: the bound ``decode`` of a shared StrictJWT instance.
decode = StrictJWT().decode