a9791ea419
This makes the OIDC lookup lazy, ensuring that the rest of the registry and app continues working even if one OIDC provider goes down.
237 lines
8.9 KiB
Python
237 lines
8.9 KiB
Python
import time
|
|
import json
|
|
import logging
|
|
import urlparse
|
|
|
|
import jwt
|
|
|
|
from cachetools import lru_cache
|
|
from cachetools.ttl import TTLCache
|
|
from cryptography.hazmat.backends import default_backend
|
|
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
|
from jwkest.jwk import KEYS
|
|
|
|
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
|
|
from oauth.login import OAuthLoginException
|
|
from util.security.jwtutil import decode, InvalidTokenError
|
|
|
|
logger = logging.getLogger(__name__)

# Relative path of the OIDC discovery document; it is urljoin'ed onto the configured OIDC_SERVER.
OIDC_WELLKNOWN = ".well-known/openid-configuration"

# Lifetime, in seconds, of entries in the per-service public key cache.
PUBLIC_KEY_CACHE_TTL = 60 * 60  # one hour

# Signing algorithms accepted when decoding OIDC id_tokens.
ALLOWED_ALGORITHMS = ['RS256']

# Passed as `leeway` (in seconds) when decoding OIDC id_tokens, to tolerate clock skew.
JWT_CLOCK_SKEW_SECONDS = 30
|
|
|
|
class DiscoveryFailureException(Exception):
    """ Signals that OIDC discovery did not succeed, for example because the discovery document
    could not be fetched or parsed. """
|
|
|
|
class PublicKeyLoadException(Exception):
    """ Signals that the public key needed to verify an OIDC token could not be loaded. """
|
|
|
|
|
|
class OIDCLoginService(OAuthService):
    """ Defines a generic service for all OpenID-connect compatible login services.

    The provider's OIDC configuration is fetched lazily through discovery on first use (see
    `_oidc_config`) rather than in the constructor.
    """

    def __init__(self, config, key_name):
        # Per-instance cache of the discovered OIDC configuration; filled in by _oidc_config.
        # Kept per instance (instead of a shared maxsize=1 LRU cache keyed on `self`) so that
        # several configured OIDC providers do not keep evicting, and re-fetching, each other.
        self._oidc_config_cache = None

        super(OIDCLoginService, self).__init__(config, key_name)

        # Public keys are cached by `kid`; a cache miss calls _load_public_key to fetch the key.
        self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key)

        # Service id: the lower-cased part of the config key before the first '_' (the whole key
        # if it has none; slicing up to str.find's -1 would drop the last character instead).
        self._id = key_name.split('_', 1)[0].lower()
        self._http_client = config['HTTPCLIENT']
        self._mailing = config.get('FEATURE_MAILING', False)

    def service_id(self):
        return self._id

    def service_name(self):
        return self.config.get('SERVICE_NAME', self.service_id())

    def get_icon(self):
        return self.config.get('SERVICE_ICON', 'fa-user-circle')

    def get_login_scopes(self):
        """ Returns the scopes to request at login.

        If the discovery document advertises `scopes_supported`, that list is returned as-is.
        Otherwise returns `openid`, plus `profile` when a userinfo endpoint exists and `email`
        when mailing is enabled. May raise DiscoveryFailureException.
        """
        default_scopes = ['openid']

        if self.user_endpoint() is not None:
            default_scopes.append('profile')

        if self._mailing:
            default_scopes.append('email')

        return self._oidc_config().get('scopes_supported', default_scopes)

    def authorize_endpoint(self):
        """ Returns the authorization endpoint with `response_type=code&` appended, leaving a
        trailing '&' so further query parameters can be concatenated. """
        endpoint = self._oidc_config().get('authorization_endpoint', '')

        # Some providers include a query string in the endpoint itself; extend it instead of
        # starting a second one with another '?'.
        separator = '&' if '?' in endpoint else '?'
        return endpoint + separator + 'response_type=code&'

    def token_endpoint(self):
        return self._oidc_config().get('token_endpoint')

    def user_endpoint(self):
        return self._oidc_config().get('userinfo_endpoint')

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Checks that the authorization URL for this service answers with a 2xx status, raising
        a generic Exception otherwise. """
        # TODO: find a way to verify client secret too.
        # NOTE(review): get_auth_url is called here without arguments; confirm this matches its
        # signature in OAuthService.
        check_auth_url = http_client.get(self.get_auth_url())
        if check_auth_url.status_code // 100 != 2:
            raise Exception('Got non-200 status code for authorization endpoint')

    def requires_form_encoding(self):
        return True

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'OIDC': True,
        }

    def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
        """ Exchanges an authorization code for the logged-in user's identity.

        Returns a tuple of (id_token subject, username, verified email address or None). Raises
        OAuthLoginException if the code exchange, id_token validation, user info lookup or the
        email requirement fails.
        """
        # Exchange the code for the access token and id_token.
        try:
            json_data = self.exchange_code(app_config, http_client, code,
                                           redirect_suffix=redirect_suffix,
                                           form_encode=self.requires_form_encoding())
        except OAuthExchangeCodeException as oce:
            raise OAuthLoginException(oce.message)
        except DiscoveryFailureException as dfe:
            # Discovery is lazy, so it may first be attempted during the exchange (presumably
            # while resolving token_endpoint); report it as a login failure.
            logger.exception('OIDC discovery failed during code exchange: %s', dfe)
            raise OAuthLoginException('Could not load OIDC discovery information')

        # Make sure we received both.
        access_token = json_data.get('access_token', None)
        if access_token is None:
            logger.debug('Missing access_token in response: %s', json_data)
            raise OAuthLoginException('Missing `access_token` in OIDC response')

        id_token = json_data.get('id_token', None)
        if id_token is None:
            logger.debug('Missing id_token in response: %s', json_data)
            raise OAuthLoginException('Missing `id_token` in OIDC response')

        # Decode and verify the id_token.
        try:
            decoded_id_token = self._decode_user_jwt(id_token)
        except InvalidTokenError as ite:
            logger.exception('Got invalid token error on OIDC decode: %s', ite)
            raise OAuthLoginException('Could not decode OIDC token')
        except PublicKeyLoadException as pke:
            logger.exception('Could not load public key during OIDC decode: %s', pke)
            raise OAuthLoginException('Could not find public OIDC key')
        except DiscoveryFailureException as dfe:
            logger.exception('OIDC discovery failed during OIDC decode: %s', dfe)
            raise OAuthLoginException('Could not load OIDC discovery information')

        # Retrieve the user information. When the provider advertises no userinfo endpoint (a
        # case get_login_scopes also allows for), use the claims of the verified id_token.
        if self.user_endpoint() is not None:
            try:
                user_info = self.get_user_info(http_client, access_token)
            except OAuthGetUserInfoException as oge:
                raise OAuthLoginException(oge.message)
        else:
            user_info = decoded_id_token

        # Verify that the user info belongs to the same subject as the id_token.
        token_subject = decoded_id_token.get('sub')
        if not token_subject:
            raise OAuthLoginException('Missing `sub` in OIDC id_token')

        if user_info.get('sub') != token_subject:
            raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')

        # Only use the email address if the provider says it is verified.
        email_address = None
        if self._claim_is_true(user_info.get('email_verified')):
            email_address = user_info.get('email')

        if self._mailing and email_address is None:
            raise OAuthLoginException('A verified email address is required to login with this service')

        # Prefer the provider's preferred username, falling back to the subject.
        lusername = user_info.get('preferred_username') or user_info.get('sub')
        return token_subject, lusername, email_address

    @staticmethod
    def _claim_is_true(value):
        """ Returns whether a boolean OIDC claim is true. Some providers send booleans as the
        strings "true"/"false"; plain truthiness would treat the string "false" as true. """
        if value is True:
            return True
        return hasattr(value, 'lower') and value.lower() == 'true'

    @property
    def _issuer(self):
        """ Returns the issuer expected in the `iss` claim of id_tokens.

        An explicitly configured OIDC_ISSUER wins. Otherwise the `issuer` advertised by discovery
        is used, because the `iss` claim must match it exactly and it may differ from OIDC_SERVER
        (e.g. by a trailing slash). OIDC_SERVER is the final fallback.
        """
        configured_issuer = self.config.get('OIDC_ISSUER')
        if configured_issuer:
            return configured_issuer

        return self._oidc_config().get('issuer') or self.config['OIDC_SERVER']

    def _oidc_config(self):
        """ Returns the provider's OIDC configuration dict, loading it via discovery on first use.

        Returns an empty dict if no OIDC_SERVER is configured. A failed discovery raises
        DiscoveryFailureException and is not cached, so the next call retries.
        """
        if self._oidc_config_cache is None:
            if self.config.get('OIDC_SERVER'):
                debugging = self.config.get('DEBUGGING', False)
                self._oidc_config_cache = self._load_oidc_config_via_discovery(debugging)
            else:
                self._oidc_config_cache = {}

        return self._oidc_config_cache

    def _load_oidc_config_via_discovery(self, is_debugging):
        """ Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is
        True, non-secure connections are allowed and `verify` is passed to the HTTP client as
        False. Raises a DiscoveryFailureException on failure.
        """
        oidc_server = self.config['OIDC_SERVER']
        if not oidc_server.startswith('https://') and not is_debugging:
            raise DiscoveryFailureException('OIDC server must be accessed over SSL')

        discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
        try:
            discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
        except Exception:
            # Transport failures (connection errors, timeouts, ...) are reported as discovery
            # failures so callers handling DiscoveryFailureException see every failure mode.
            logger.exception('Could not fetch OIDC discovery for url: %s', discovery_url)
            raise DiscoveryFailureException("Could not load OIDC discovery information")

        if discovery.status_code // 100 != 2:
            logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
            raise DiscoveryFailureException("Could not load OIDC discovery information")

        try:
            oidc_config = json.loads(discovery.text)
        except ValueError:
            logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
            raise DiscoveryFailureException("Could not parse OIDC discovery information")

        # Callers use the result as a dict (.get / item access); reject any other JSON value.
        if not isinstance(oidc_config, dict):
            logger.debug('OIDC discovery for url %s is not a JSON object', discovery_url)
            raise DiscoveryFailureException("Could not parse OIDC discovery information")

        return oidc_config

    def _decode_user_jwt(self, token):
        """ Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
        exception on an invalid token or a PublicKeyLoadException if the public key could not be
        loaded for decoding.
        """
        # Find the key to use.
        headers = jwt.get_unverified_header(token)
        kid = headers.get('kid', None)
        if kid is None:
            raise InvalidTokenError('Missing `kid` header')

        def decode_with(public_key):
            # Audience and issuer are read after the key is obtained, as in a single decode call.
            return decode(token, public_key, algorithms=ALLOWED_ALGORITHMS,
                          audience=self.client_id(),
                          issuer=self._issuer,
                          leeway=JWT_CLOCK_SKEW_SECONDS,
                          options=dict(require_nbf=False))

        try:
            return decode_with(self._get_public_key(kid))
        except InvalidTokenError:
            # Public key may have expired. Try to retrieve an updated public key and use it to decode.
            return decode_with(self._get_public_key(kid, force_refresh=True))

    def _get_public_key(self, kid, force_refresh=False):
        """ Retrieves the public key for this handler with the given kid. Raises a
        PublicKeyLoadException on failure. """
        # If force_refresh is true, we expire all the items in the cache by setting the time to
        # the current time + the expiration TTL.
        if force_refresh:
            self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

        # Retrieve the public key from the cache. If the cache does not contain the public key,
        # it will internally call _load_public_key to retrieve it and then save it.
        return self._public_key_cache[kid]

    def _load_public_key(self, kid):
        """ Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
        on failure, including when discovery fails or advertises no `jwks_uri`.
        """
        try:
            oidc_config = self._oidc_config()
        except DiscoveryFailureException as dfe:
            raise PublicKeyLoadException(str(dfe))

        keys_url = oidc_config.get('jwks_uri')
        if not keys_url:
            raise PublicKeyLoadException('OIDC discovery information has no `jwks_uri`')

        # Load the keys.
        try:
            keys = KEYS()
            keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
        except Exception as ex:
            logger.exception('Exception loading public key')
            raise PublicKeyLoadException(str(ex))

        # Find the matching key.
        keys_found = keys.by_kid(kid)
        if not keys_found:
            raise PublicKeyLoadException('Public key %s not found' % kid)

        rsa_keys = [key for key in keys_found if key.kty == 'RSA']
        if not rsa_keys:
            raise PublicKeyLoadException('No RSA form of public key %s found' % kid)

        matching_key = rsa_keys[0]
        matching_key.deserialize()

        # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
        # issues.
        return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())
|