81 lines
		
	
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			81 lines
		
	
	
	
		
			2.9 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import re
import time

from datetime import datetime, timedelta

from jwt import PyJWT
from jwt.exceptions import (
  InvalidTokenError, DecodeError, InvalidAudienceError, ExpiredSignatureError,
  ImmatureSignatureError, InvalidIssuedAtError, InvalidIssuerError, MissingRequiredClaimError
)

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers
from jwkest.jwk import keyrep, RSAKey, ECKey
 | |
| 
 | |
| 
 | |
# TOKEN_REGEX matches a whole string of the form "Bearer <token>" (presumably an
# Authorization header value -- confirm against callers). <token> must consist of two or
# more non-empty, dot-separated segments drawn from the base64/base64url alphabet
# (letters, digits, '+', '-', '_', '/'); '=' padding and trailing newlines are rejected.
# Group 1 captures the token itself.
TOKEN_REGEX = re.compile(
  r'\ABearer (([a-zA-Z0-9+\-_/]+\.)+[a-zA-Z0-9+\-_/]+)\Z'
)
 | |
| 
 | |
| 
 | |
class StrictJWT(PyJWT):
  """ StrictJWT defines a JWT decoder with extra checks.

  Compared to the stock PyJWT decoder it:

  * defaults the `require_exp`, `require_iat` and `require_nbf` options to True, and
  * supports an `exp_max_s` option that rejects tokens whose `exp` lies more than
    `exp_max_s` seconds after their `iat` (or after the current time when `iat` is absent).
  """

  @staticmethod
  def _get_default_options():
    """ Returns PyJWT's default options extended with the stricter defaults listed above. """
    # Weird syntax to call super on a staticmethod
    defaults = super(StrictJWT, StrictJWT)._get_default_options()
    defaults.update({
      'require_exp': True,
      'require_iat': True,
      'require_nbf': True,
      'exp_max_s': None,
    })
    return defaults

  def _validate_claims(self, payload, options, audience=None, issuer=None, leeway=0, **kwargs):
    """ Validates the claims of a decoded token, adding the `exp_max_s` lifetime check.

    Args:
      payload: the decoded claims.
      options: the merged decode options; `exp_max_s` (seconds, or None) enables the
               lifetime check and forces `verify_exp` on.
      audience, issuer, leeway, kwargs: forwarded unchanged to PyJWT's validation.

    Raises:
      ValueError: if `exp_max_s` is set while the `verify_expiration` keyword argument
                  explicitly disables expiration checking.
      InvalidTokenError: (or a subclass) if any claim check fails, including a token signed
                         for longer than `exp_max_s` seconds.
    """
    max_signed_s = options.get('exp_max_s')
    if max_signed_s is not None:
      if 'verify_expiration' in kwargs and not kwargs.get('verify_expiration'):
        raise ValueError('exp_max_s option implies verify_expiration')

      # Work on a copy so that forcing verify_exp does not leak into the caller's dict.
      options = dict(options)
      options['verify_exp'] = True

    # Do all of the other checks
    super(StrictJWT, self)._validate_claims(payload, options, audience, issuer, leeway, **kwargs)

    if 'exp' not in payload or max_signed_s is None:
      return

    # Validate that the expiration was not more than exp_max_s seconds after the issue time
    # or, in the absence of an issue time, more than exp_max_s seconds in the future from now.
    # The comparison uses raw epoch seconds instead of datetimes because
    # datetime.utcfromtimestamp raises ValueError/OverflowError for out-of-range values,
    # which would escape callers that only catch InvalidTokenError.

    # This will work because the parent method already checked the type of exp
    # (verify_exp is forced on above whenever exp_max_s is set).
    expiration = int(payload['exp'])

    if 'iat' in payload:
      try:
        start_time = int(payload['iat'])
      except (TypeError, ValueError):
        # The parent only type-checks iat when verify_iat is enabled.
        raise InvalidIssuedAtError('Issued At claim (iat) must be an integer.')
    else:
      start_time = time.time()

    if expiration > start_time + max_signed_s:
      raise InvalidTokenError('Token was signed for more than %s seconds from %s' %
                              (max_signed_s, start_time))
 | |
| 
 | |
| 
 | |
def exp_max_s_option(max_exp_s):
  """ Builds a decode `options` dict that caps a token's lifetime at `max_exp_s` seconds.

  The returned dict holds a single `exp_max_s` entry, the option read by
  StrictJWT._validate_claims.
  """
  return dict(exp_max_s=max_exp_s)
 | |
| 
 | |
| 
 | |
# Module-level convenience: the bound `decode` method of a single StrictJWT instance created
# at import time, so callers can decode tokens with the stricter checks without building
# their own decoder.
decode = StrictJWT().decode
 | |
| 
 | |
| 
 | |
def jwk_dict_to_public_key(jwk):
  """ Converts the specified JWK into a public key.

  Args:
    jwk: the JWK to convert, passed straight to jwkest's `keyrep` (a dict, per the
         function name).

  Returns:
    A `cryptography` public key: built from RSAPublicNumbers for RSA keys, or from
    EllipticCurvePublicNumbers for EC keys.

  Raises:
    Exception: if `keyrep` yields any other kind of key.
  """
  jwkest_key = keyrep(jwk)
  if isinstance(jwkest_key, RSAKey):
    # `.key` is presumably the underlying PyCrypto RSA key (per the variable name); only its
    # public exponent and modulus are used to rebuild a `cryptography` key.
    pycrypto_key = jwkest_key.key
    return RSAPublicNumbers(e=pycrypto_key.e, n=pycrypto_key.n).public_key(default_backend())

  if isinstance(jwkest_key, ECKey):
    # NOTE(review): EllipticCurvePublicNumbers expects a `cryptography` EllipticCurve
    # instance as its curve argument; confirm that jwkest's `ECKey.curve` satisfies it.
    x, y = jwkest_key.get_key()
    return EllipticCurvePublicNumbers(x, y, jwkest_key.curve).public_key(default_backend())

  # Format the message explicitly; passing the value as a second argument to Exception
  # leaves the '%s' placeholder uninterpolated.
  raise Exception('Unsupported kind of JWK: %s' % str(type(jwkest_key)))
 |