import time
import json
import logging
import urlparse

from cachetools import lru_cache
from cachetools.ttl import TTLCache

from util.oauth.base import OAuthService

# NOTE(review): `decode` and `InvalidTokenError` (presumably PyJWT), `KEYS` (presumably
# jwkest.jwk) and `load_der_public_key` / `default_backend` (presumably the `cryptography`
# package) are used below but are not imported in this module. Unless they are injected some
# other way, decode_user_jwt and _get_public_key will raise NameError -- confirm and add the
# imports.

logger = logging.getLogger(__name__)


def decode_user_jwt(token, oidc_provider):
    """ Decodes and verifies a JWT issued by the given OIDC provider.

        The token is checked against the provider's public key using RS256, with the audience
        required to match the provider's client ID and the issuer required to match the
        provider's issuer. If verification with the cached public key fails, the key is
        force-refreshed once and verification is retried, in case the provider rotated its key.

        Returns whatever `decode` returns (presumably the claims dict). If the retry also fails,
        its exception propagates.
    """
    try:
        return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'],
                      audience=oidc_provider.client_id(), issuer=oidc_provider.issuer)
    except InvalidTokenError:
        # Public key may have expired. Try to retrieve an updated public key and use it to decode.
        # NOTE(review): this retries on every InvalidTokenError, not just signature failures, so
        # (if InvalidTokenError is PyJWT's base class) expired or malformed tokens also trigger a
        # key re-fetch from the provider.
        return decode(token, oidc_provider.get_public_key(force_refresh=True),
                      algorithms=['RS256'], audience=oidc_provider.client_id(),
                      issuer=oidc_provider.issuer)


OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600  # 1 hour


class OIDCConfig(OAuthService):
    """ OAuth service for a generic OpenID Connect provider, whose endpoints and signing key are
        loaded from the provider's well-known discovery document.
    """

    def __init__(self, config, key_name):
        # Set before the base-class initializer runs so _oidc_config never sees a missing
        # attribute. None means "discovery document not loaded yet".
        self._discovered_config = None

        super(OIDCConfig, self).__init__(config, key_name)

        # Single-entry cache for the provider's public key. Entries expire after
        # PUBLIC_KEY_CACHE_TTL seconds and are (re)loaded on demand through _get_public_key.
        self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
        self._config = config
        self._http_client = config['HTTPCLIENT']

    def _oidc_config(self):
        """ Returns the provider's OIDC discovery document, or {} if OIDC_SERVER is not set.

            The result is memoized on this instance. A failed load is not memoized, so it is
            retried on the next call.
        """
        # Memoize per instance rather than with a maxsize=1 LRU on the method. That LRU is
        # shared by every OIDCConfig instance, so two providers would keep evicting each other
        # and re-fetching discovery, and the cache would also keep instances alive.
        if self._discovered_config is None:
            if self.config.get('OIDC_SERVER'):
                self._discovered_config = self._load_via_discovery(
                    self._config.get('DEBUGGING', False))
            else:
                self._discovered_config = {}
        return self._discovered_config

    def _load_via_discovery(self, is_debugging):
        """ Fetches and parses the provider's OpenID Connect discovery document.

            The document is requested from `<OIDC_SERVER>/.well-known/openid-configuration`.

            Args:
              is_debugging: when true, a non-https OIDC_SERVER is allowed.

            Returns the parsed JSON document.

            Raises Exception if OIDC_SERVER is not https (outside debugging), if the response
            status is not 2xx, or if the response body is not valid JSON.
        """
        oidc_server = self.config['OIDC_SERVER']
        if not oidc_server.startswith('https://') and not is_debugging:
            raise Exception('OIDC server must be accessed over SSL')

        # urljoin replaces the last path segment of a base URL that has no trailing slash
        # (https://host/realm + .well-known/... -> https://host/.well-known/...). Treat the
        # server URL as a directory so the well-known path is appended beneath it.
        if not oidc_server.endswith('/'):
            oidc_server += '/'

        discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
        discovery = self._http_client.get(discovery_url, timeout=5)
        # Floor division keeps this an integer status-class check on Python 3 as well.
        if discovery.status_code // 100 != 2:
            raise Exception("Could not load OIDC discovery information")

        try:
            return json.loads(discovery.text)
        except ValueError:
            logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
            raise Exception("Could not parse OIDC discovery information")

    def authorize_endpoint(self):
        """ Returns the discovered authorization endpoint followed by a query separator.

            Presumably callers append URL-encoded parameters directly after the separator. The
            separator is '?' normally, or '&' when the endpoint already carries a query string
            (some providers include fixed query parameters in it).
        """
        endpoint = self._oidc_config().get('authorization_endpoint', '')
        if '?' not in endpoint:
            return endpoint + '?'
        if endpoint.endswith(('?', '&')):
            return endpoint
        return endpoint + '&'

    def token_endpoint(self):
        """ Returns the discovered token endpoint, or None if it was not advertised. """
        return self._oidc_config().get('token_endpoint')

    def user_endpoint(self):
        """ Returns None: no user endpoint is exposed for OIDC providers. """
        return None

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Performs no validation for OIDC providers. """
        pass

    def get_public_config(self):
        """ Returns the subset of this service's configuration exposed publicly. """
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
            'OIDC': True,
        }

    @property
    def issuer(self):
        """ The expected token issuer: OIDC_ISSUER if configured, otherwise OIDC_SERVER. """
        # Look up OIDC_SERVER only as a fallback, so a config that sets OIDC_ISSUER does not
        # also need OIDC_SERVER just to read this property.
        try:
            return self.config['OIDC_ISSUER']
        except KeyError:
            return self.config['OIDC_SERVER']

    def get_public_key(self, force_refresh=False):
        """ Retrieves the public key for this handler.

            Args:
              force_refresh: when true, every cached entry is expired first, so the key is
                re-fetched from the provider.
        """
        # If force_refresh is true, we expire all the items in the cache by setting the time to
        # the current time + the expiration TTL.
        if force_refresh:
            self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

        # Retrieve the public key from the cache. If the cache does not contain the public key,
        # it will internally call _get_public_key to retrieve it and then save it. The None is
        # an arbitrary key chosen to be stored in the cache, and could be anything.
        return self._public_key_cache[None]

    def _get_public_key(self, _):
        """ Loads the provider's signing key from the discovered `jwks_uri`.

            Used as the public-key cache's miss handler. The cache key argument is ignored.

            Returns a public key instance built from the first key in the provider's key set.

            Raises Exception if the provider publishes no keys.
        """
        keys_url = self._oidc_config()['jwks_uri']
        keys = KEYS()
        keys.load_from_url(keys_url)

        key_list = list(keys)
        if not key_list:
            raise Exception('No keys provided by OIDC provider')

        # NOTE(review): assumes the first published key is the RSA signing key. Providers that
        # publish several keys (e.g. during rotation) may need selection by the token's `kid`.
        rsa_key = key_list[0]
        rsa_key.deserialize()

        # Reload the key so that we can give a key *instance* to PyJWT to work around its weird
        # parsing issues.
        return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())