111 lines
3.7 KiB
Python
111 lines
3.7 KiB
Python
|
import time
|
||
|
import json
|
||
|
import logging
|
||
|
import urlparse
|
||
|
|
||
|
from cachetools import lru_cache
|
||
|
from cachetools.ttl import TTLCache
|
||
|
|
||
|
from util.oauth.base import OAuthService
|
||
|
|
||
|
# Logger for this module; named after the module via __name__.
logger = logging.getLogger(name=__name__)
def decode_user_jwt(token, oidc_provider):
    """Decode and validate an RS256-signed user JWT against ``oidc_provider``.

    The token must verify with the provider's public key. It must also carry an
    audience equal to ``oidc_provider.client_id()`` and an issuer equal to
    ``oidc_provider.issuer``.

    If the first attempt raises ``InvalidTokenError``, the public key is fetched
    again with ``force_refresh=True`` and decoding is retried once. This covers a
    provider that has rotated its key since it was cached. An error from the
    retry propagates to the caller.

    :param token: the encoded JWT.
    :param oidc_provider: object exposing ``get_public_key(force_refresh=...)``,
        ``client_id()`` and ``issuer`` (e.g. an ``OIDCConfig``).
    :returns: whatever ``decode`` returns (presumably the claims dict).

    NOTE(review): ``decode`` and ``InvalidTokenError`` are not imported in the
    visible module header. Presumably they are PyJWT's ``jwt.decode`` /
    ``jwt.InvalidTokenError``; confirm the import exists.
    NOTE(review): *every* ``InvalidTokenError`` triggers a forced key refresh,
    including failures unrelated to the key, such as a bad audience or an expired
    token. Each such token therefore costs a key re-fetch.
    """
    def _decode_with(public_key):
        # Shared decode call so both attempts apply identical validation rules.
        return decode(token, public_key, algorithms=['RS256'],
                      audience=oidc_provider.client_id(),
                      issuer=oidc_provider.issuer)

    try:
        return _decode_with(oidc_provider.get_public_key())
    except InvalidTokenError:
        # The cached key may be stale; retry once with a freshly fetched key.
        return _decode_with(oidc_provider.get_public_key(force_refresh=True))
# Relative path of the OpenID Connect discovery document; joined onto OIDC_SERVER.
OIDC_WELLKNOWN = ".well-known/openid-configuration"

# How long, in seconds, a fetched provider public key stays cached (one hour).
PUBLIC_KEY_CACHE_TTL = 60 * 60
class OIDCConfig(OAuthService):
    """OAuth service backed by an OpenID Connect (OIDC) provider.

    Endpoints come from the provider's discovery document, found at
    ``OIDC_SERVER`` + ``.well-known/openid-configuration``. The public signing key
    is loaded from the ``jwks_uri`` that document advertises and cached for
    ``PUBLIC_KEY_CACHE_TTL`` seconds.

    ``self.config`` is set up by ``OAuthService`` (presumably the provider-specific
    config section; verify against the base class). ``self._config`` is the full
    config object passed to ``__init__``.
    """

    def __init__(self, config, key_name):
        super(OIDCConfig, self).__init__(config, key_name)

        # Single-slot TTL cache. On a miss it calls _get_public_key to fetch the key.
        self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
        self._config = config
        self._http_client = config['HTTPCLIENT']

        # Per-instance memo for the discovery document (see _oidc_config).
        self._oidc_config_cache = None

    def _oidc_config(self):
        """Return the provider's discovery document, loading it at most once per instance.

        :returns: the parsed discovery dict, or ``{}`` when no ``OIDC_SERVER`` is configured.
        :raises Exception: propagated from ``_load_via_discovery``. A failed load is not
            cached, so the next call retries.
        """
        # Memoize on the instance instead of decorating the method with lru_cache.
        # A method-level cache keys on ``self`` and keeps the instance alive. With
        # maxsize=1 it is also evicted whenever a second instance is used.
        # Concurrent first calls may both fetch; the last result wins, which is harmless.
        cached = getattr(self, '_oidc_config_cache', None)
        if cached is None:
            if self.config.get('OIDC_SERVER'):
                cached = self._load_via_discovery(self._config.get('DEBUGGING', False))
            else:
                cached = {}
            self._oidc_config_cache = cached
        return cached

    def _load_via_discovery(self, is_debugging):
        """Fetch and parse the discovery document from ``OIDC_SERVER``.

        :param is_debugging: when true, a non-``https://`` ``OIDC_SERVER`` is allowed.
        :returns: the parsed discovery JSON.
        :raises Exception: if the server is not HTTPS (outside debugging), if the HTTP
            status is not 2xx, or if the body is not valid JSON.
        """
        oidc_server = self.config['OIDC_SERVER']
        if not oidc_server.startswith('https://') and not is_debugging:
            raise Exception('OIDC server must be accessed over SSL')

        # OIDC Discovery places the document under the server URL's own path.
        # urljoin drops the last path segment of a base without a trailing slash
        # (https://host/realms/foo -> https://host/realms/.well-known/...), so add one.
        base_url = oidc_server if oidc_server.endswith('/') else oidc_server + '/'
        discovery_url = urlparse.urljoin(base_url, OIDC_WELLKNOWN)
        discovery = self._http_client.get(discovery_url, timeout=5)

        # Floor division keeps the 2xx check correct under Python 3 too. There,
        # ``/`` would turn e.g. 204 into 2.04 and wrongly reject the response.
        if discovery.status_code // 100 != 2:
            raise Exception("Could not load OIDC discovery information")

        try:
            return json.loads(discovery.text)
        except ValueError:
            logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
            raise Exception("Could not parse OIDC discovery information")

    def authorize_endpoint(self):
        """Return the discovered ``authorization_endpoint`` followed by ``'?'``.

        Returns just ``'?'`` when the document has no such entry. The trailing ``?``
        presumably lets callers append query parameters directly.
        """
        return self._oidc_config().get('authorization_endpoint', '') + '?'

    def token_endpoint(self):
        """Return the discovered ``token_endpoint``, or None if absent."""
        return self._oidc_config().get('token_endpoint')

    def user_endpoint(self):
        """No user-info endpoint is provided; always returns None."""
        return None

    def validate_client_id_and_secret(self, http_client, app_config):
        """No-op: this service performs no client id/secret validation."""
        pass

    def get_public_config(self):
        """Return the client id, the authorize endpoint and an ``OIDC: True`` marker."""
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
            'OIDC': True,
        }

    @property
    def issuer(self):
        """Expected token issuer: ``OIDC_ISSUER`` if configured, else ``OIDC_SERVER``."""
        # Check membership first so OIDC_SERVER is read only when OIDC_ISSUER is
        # absent. A fallback passed as a .get() default is evaluated eagerly and
        # would raise KeyError even when OIDC_ISSUER is set.
        if 'OIDC_ISSUER' in self.config:
            return self.config['OIDC_ISSUER']
        return self.config['OIDC_SERVER']

    def get_public_key(self, force_refresh=False):
        """Return the provider's public signing key, served from a TTL cache.

        :param force_refresh: when true, expire the cached key first so it is re-fetched.
        """
        if force_refresh:
            # Advance the expiry clock by a full TTL so cached entries count as expired.
            self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

        # On a miss the cache calls its ``missing`` hook (_get_public_key) and stores
        # the result. ``None`` is an arbitrary fixed key for the single cache slot.
        return self._public_key_cache[None]

    def _get_public_key(self, _):
        """Fetch the provider's signing keys from ``jwks_uri`` and return the first as a key object.

        Serves as the ``missing`` hook of the public-key cache; the cache-key argument is ignored.

        :raises Exception: if the provider publishes no keys.

        NOTE(review): ``KEYS``, ``load_der_public_key`` and ``default_backend`` are not
        imported in the visible module header. Presumably they are ``jwkest.jwk.KEYS``
        and the ``cryptography`` serialization/backends helpers; confirm.
        NOTE(review): the first key is used regardless of any ``kid`` in the token
        header. Providers publishing several keys may need kid-based selection.
        """
        keys_url = self._oidc_config()['jwks_uri']

        keys = KEYS()
        keys.load_from_url(keys_url)

        key_list = list(keys)
        if not key_list:
            raise Exception('No keys provided by OIDC provider')

        rsa_key = key_list[0]
        rsa_key.deserialize()

        # Re-load the key from DER so PyJWT receives a key *instance*, working around
        # its parsing quirks.
        return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())