19f7acf575
Moves all the external login services into a set of classes that share as much code as possible. These services are then registered on both the client and the server, allowing us, in the follow-up change, to dynamically register new handlers.
110 lines
3.7 KiB
Python
110 lines
3.7 KiB
Python
import time
|
|
import json
|
|
import logging
|
|
import urlparse
|
|
|
|
from cachetools import lru_cache
|
|
from cachetools.ttl import TTLCache
|
|
|
|
from util.oauth.base import OAuthService
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def decode_user_jwt(token, oidc_provider):
    """ Decodes `token` using the public key of `oidc_provider`.

        Calls `decode` with RS256 as the only allowed algorithm, the provider's
        client id as the audience and `oidc_provider.issuer` as the issuer. If that
        raises InvalidTokenError, the provider's public key is force-refreshed once
        and decoding is retried; any error from the retry propagates to the caller.
    """
    # NOTE(review): `decode` and `InvalidTokenError` are not imported in the visible
    # header of this file -- confirm where they come from (presumably PyJWT or a
    # project wrapper around it).
    # NOTE(review): every InvalidTokenError (not only a signature mismatch) triggers
    # the forced key refresh, so each invalid token may cost a fetch from the provider.
    def _decode_with(public_key):
        return decode(token, public_key, algorithms=['RS256'],
                      audience=oidc_provider.client_id(),
                      issuer=oidc_provider.issuer)

    try:
        return _decode_with(oidc_provider.get_public_key())
    except InvalidTokenError:
        # The cached public key may be stale; fetch a fresh one and try exactly once more.
        refreshed_key = oidc_provider.get_public_key(force_refresh=True)
        return _decode_with(refreshed_key)
|
|
|
|
# Path of the OpenID Connect discovery document, joined onto the OIDC server URL.
OIDC_WELLKNOWN = ".well-known/openid-configuration"

# Lifetime, in seconds, of a cached provider public key (one hour).
PUBLIC_KEY_CACHE_TTL = 60 * 60
|
|
|
|
class OIDCConfig(OAuthService):
    """ External login service backed by an OpenID Connect (OIDC) provider.

        Endpoints come from the provider's discovery document, fetched from
        OIDC_SERVER. The provider's public key is fetched from the document's
        `jwks_uri` and cached for PUBLIC_KEY_CACHE_TTL seconds.
    """

    def __init__(self, config, key_name):
        super(OIDCConfig, self).__init__(config, key_name)

        # Parsed discovery document, memoized per instance once it loads successfully.
        # A method-level lru_cache(maxsize=1) is keyed on `self`, so several OIDC
        # services would evict each other (re-running HTTP discovery on every call)
        # and the cache would keep instances alive.
        self._oidc_config_cache = None

        # Single-entry cache; on a miss it calls _get_public_key to fetch the key.
        self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)

        # `config` is the object passed in (it supplies HTTPCLIENT and DEBUGGING).
        # OIDC_SERVER / OIDC_ISSUER are read from `self.config`, which the base class
        # presumably derives from `config` and `key_name` -- confirm in OAuthService.
        self._config = config
        self._http_client = config['HTTPCLIENT']

    def _oidc_config(self):
        """ Returns the provider's discovery document as a dict, loading it on first use.

            Returns an empty dict when OIDC_SERVER is not configured. Only successful
            loads are memoized; if discovery raises, the next call retries it.
        """
        # Unsynchronized: concurrent first calls may each run discovery, which only
        # wastes a request -- every caller still ends up with a parsed document.
        if self._oidc_config_cache is None:
            if self.config.get('OIDC_SERVER'):
                self._oidc_config_cache = self._load_via_discovery(self._config.get('DEBUGGING', False))
            else:
                self._oidc_config_cache = {}
        return self._oidc_config_cache

    def _load_via_discovery(self, is_debugging):
        """ Fetches and parses the discovery document from OIDC_SERVER.

            Raises Exception if the server URL is not https:// (unless `is_debugging`),
            if the response status is not 2xx, or if the body is not valid JSON.
        """
        oidc_server = self.config['OIDC_SERVER']
        if not oidc_server.startswith('https://') and not is_debugging:
            raise Exception('OIDC server must be accessed over SSL')

        # urljoin replaces the last path segment of a base without a trailing slash
        # (".../realms/foo" -> ".../realms/.well-known/..."). OIDC Discovery places the
        # document under the issuer's own path, so make the base end in '/'.
        if not oidc_server.endswith('/'):
            oidc_server += '/'

        discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
        discovery = self._http_client.get(discovery_url, timeout=5)

        # Floor division keeps this an integer 2xx check under true division as well.
        if discovery.status_code // 100 != 2:
            raise Exception("Could not load OIDC discovery information")

        try:
            return json.loads(discovery.text)
        except ValueError:
            logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
            raise Exception("Could not parse OIDC discovery information")

    def authorize_endpoint(self):
        """ Returns the discovered authorization endpoint followed by '?'.

            Returns just '?' when discovery has no `authorization_endpoint` (the
            trailing '?' is presumably there so callers can append query parameters).
        """
        return self._oidc_config().get('authorization_endpoint', '') + '?'

    def token_endpoint(self):
        """ Returns the discovered token endpoint, or None if it is not advertised. """
        return self._oidc_config().get('token_endpoint')

    def user_endpoint(self):
        """ This service exposes no user endpoint; always returns None. """
        return None

    def validate_client_id_and_secret(self, http_client, app_config):
        """ No-op: this service performs no client id/secret validation. """
        pass

    def get_public_config(self):
        """ Returns the client-facing settings: client id, authorize endpoint, OIDC flag. """
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
            'OIDC': True,
        }

    @property
    def issuer(self):
        """ Issuer used when validating tokens: OIDC_ISSUER if set, otherwise OIDC_SERVER. """
        # OIDC_SERVER is looked up only when it is actually needed. The previous
        # dict.get(key, default) form evaluated the default eagerly and raised
        # KeyError whenever OIDC_SERVER was absent, even with OIDC_ISSUER set.
        if 'OIDC_ISSUER' in self.config:
            return self.config['OIDC_ISSUER']
        return self.config['OIDC_SERVER']

    def get_public_key(self, force_refresh=False):
        """ Returns the provider's public key, fetching it if it is not cached.

            With `force_refresh`, any cached key is discarded first, so a fresh key is
            fetched from the provider.
        """
        if force_refresh:
            # Drop the (single) cached entry outright. Unlike expiring it against a
            # synthetic time.time()-based timestamp, this does not depend on which
            # timer the cache was built with.
            self._public_key_cache.clear()

        # On a miss the cache calls _get_public_key and stores the result. The key
        # `None` is arbitrary; the cache only ever holds this one entry.
        return self._public_key_cache[None]

    def _get_public_key(self, _):
        """ Cache-miss callback: fetches the provider's public key from its JWKS URI.

            The argument is the unused cache key. Raises KeyError if the discovery
            document has no `jwks_uri`, and Exception if the key set is empty.
        """
        # NOTE(review): KEYS, load_der_public_key and default_backend are not imported
        # in the visible header of this file -- confirm they are in scope.
        keys_url = self._oidc_config()['jwks_uri']

        keys = KEYS()
        keys.load_from_url(keys_url)

        key_list = list(keys)
        if not key_list:
            raise Exception('No keys provided by OIDC provider')

        # NOTE(review): always uses the first published key and ignores any `kid`;
        # providers that publish several keys may need kid-based selection.
        rsa_key = key_list[0]
        rsa_key.deserialize()

        # Reload the key so that we can give a key *instance* to PyJWT to work around
        # its weird parsing issues.
        return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
|