import time
import json
import logging
import urlparse

import jwt

from cachetools import lru_cache
from cachetools.ttl import TTLCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS

from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError


logger = logging.getLogger(__name__)

OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600  # 1 hour
ALLOWED_ALGORITHMS = ['RS256']
JWT_CLOCK_SKEW_SECONDS = 30


class DiscoveryFailureException(Exception):
  """ Exception raised when OIDC discovery fails. """
  pass


class PublicKeyLoadException(Exception):
  """ Exception raised if loading the OIDC public key fails. """
  pass


class OIDCLoginService(OAuthService):
  """ Defines a generic service for all OpenID-connect compatible login services. """

  def __init__(self, config, key_name, client=None):
    super(OIDCLoginService, self).__init__(config, key_name)

    # Single-entry TTL cache of the provider's signing key; on a miss it calls
    # _load_public_key(kid) to fetch the key from the provider's JWKS endpoint.
    self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key)

    # The service ID is the lowercased text before the first underscore of the config key
    # (e.g. `OIDC_LOGIN_CONFIG` -> `oidc`); a key without an underscore is used whole.
    self._id = key_name.partition('_')[0].lower()
    self._http_client = client or config['HTTPCLIENT']
    self._mailing = config.get('FEATURE_MAILING', False)

    # Per-instance memo of the OIDC discovery document; populated lazily by _oidc_config().
    self._cached_oidc_config = None

  def service_id(self):
    """ Returns the lowercased service identifier derived from the config key name. """
    return self._id

  def service_name(self):
    """ Returns the configured SERVICE_NAME, falling back to the service ID. """
    return self.config.get('SERVICE_NAME', self.service_id())

  def get_icon(self):
    """ Returns the configured SERVICE_ICON, falling back to `fa-user-circle`. """
    return self.config.get('SERVICE_ICON', 'fa-user-circle')

  def get_login_scopes(self):
    """ Returns the scopes to request at login.

        If the provider's discovery document lists `scopes_supported`, that list is returned
        as-is. Otherwise the defaults are `openid`, plus `profile` when a userinfo endpoint is
        known and `email` when mailing is enabled.
    """
    default_scopes = ['openid']

    if self.user_endpoint() is not None:
      default_scopes.append('profile')

    if self._mailing:
      default_scopes.append('email')

    return self._oidc_config().get('scopes_supported', default_scopes)

  def authorize_endpoint(self):
    """ Returns the authorization endpoint URL with `response_type=code&` appended, so that
        further `key=value` parameters can be concatenated directly.

        If the advertised endpoint already carries a query string, `&` is used as the separator
        instead of a second `?`, which would produce a malformed URL.
    """
    endpoint = self._oidc_config().get('authorization_endpoint', '')
    separator = '&' if '?' in endpoint else '?'
    return endpoint + separator + 'response_type=code&'

  def token_endpoint(self):
    """ Returns the token endpoint from discovery, or None if not advertised. """
    return self._oidc_config().get('token_endpoint')

  def user_endpoint(self):
    """ Returns the userinfo endpoint from discovery, or None if not advertised. """
    return self._oidc_config().get('userinfo_endpoint')

  def validate(self):
    """ Returns True if discovery advertised a (non-empty) userinfo endpoint.

        May raise DiscoveryFailureException if discovery itself fails.
    """
    return bool(self.user_endpoint())

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Performs a GET against the auth URL and raises if the response is not 2xx. """
    # TODO: find a way to verify client secret too.
    # NOTE(review): get_auth_url() is called with no arguments here; confirm this matches the
    # signature of OAuthService.get_auth_url.
    check_auth_url = http_client.get(self.get_auth_url())
    if check_auth_url.status_code // 100 != 2:
      raise Exception('Got non-200 status code for authorization endpoint')

  def requires_form_encoding(self):
    """ The code-for-token exchange is sent form-encoded (see exchange_code_for_login). """
    return True

  def get_public_config(self):
    """ Returns the non-secret configuration for this service: its client ID and an OIDC flag. """
    return {
      'CLIENT_ID': self.client_id(),
      'OIDC': True,
    }

  def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
    """ Exchanges an authorization code for the user's login information.

        Returns a tuple of (id_token `sub`, username, verified email or None). The username is the
        userinfo `preferred_username`, falling back to its `sub`.

        Raises OAuthLoginException on any failure: code exchange, a missing access_token or
        id_token, an undecodable id_token, a userinfo lookup failure, a missing or mismatched
        `sub`, or (when mailing is enabled) no verified email address.
    """
    # Exchange the code for the access token and id_token
    try:
      json_data = self.exchange_code(app_config, http_client, code,
                                     redirect_suffix=redirect_suffix,
                                     form_encode=self.requires_form_encoding())
    except OAuthExchangeCodeException as oce:
      raise OAuthLoginException(oce.message)

    # Make sure we received both.
    access_token = json_data.get('access_token', None)
    if access_token is None:
      logger.debug('Missing access_token in response: %s', json_data)
      raise OAuthLoginException('Missing `access_token` in OIDC response')

    id_token = json_data.get('id_token', None)
    if id_token is None:
      logger.debug('Missing id_token in response: %s', json_data)
      raise OAuthLoginException('Missing `id_token` in OIDC response')

    # Decode the id_token.
    try:
      decoded_id_token = self._decode_user_jwt(id_token)
    except InvalidTokenError as ite:
      logger.exception('Got invalid token error on OIDC decode: %s', ite.message)
      raise OAuthLoginException('Could not decode OIDC token')
    except PublicKeyLoadException as pke:
      logger.exception('Could not load public key during OIDC decode: %s', pke.message)
      raise OAuthLoginException('Could not find public OIDC key')

    # Retrieve the user information.
    try:
      user_info = self.get_user_info(http_client, access_token)
    except OAuthGetUserInfoException as oge:
      raise OAuthLoginException(oge.message)

    # Verify subs. Missing claims are reported as login failures rather than KeyErrors.
    id_token_sub = decoded_id_token.get('sub')
    if id_token_sub is None:
      raise OAuthLoginException('Missing `sub` in OIDC id_token')

    if user_info.get('sub') != id_token_sub:
      raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')

    # Check if we have a verified email address.
    email_address = user_info.get('email') if user_info.get('email_verified') else None
    if self._mailing and email_address is None:
      raise OAuthLoginException('A verified email address is required to login with this service')

    # Check for a preferred username.
    lusername = user_info.get('preferred_username') or user_info.get('sub')
    return id_token_sub, lusername, email_address

  @property
  def _issuer(self):
    """ The expected `iss` of id_tokens: OIDC_ISSUER if configured, else OIDC_SERVER. """
    # OIDC_SERVER is only looked up when no issuer is configured; passing it as a dict.get
    # default would evaluate it eagerly and raise KeyError even with OIDC_ISSUER set.
    issuer = self.config.get('OIDC_ISSUER')
    if issuer is not None:
      return issuer
    return self.config['OIDC_SERVER']

  def _oidc_config(self):
    """ Returns the provider's OIDC discovery document, or {} if no OIDC_SERVER is configured.

        The result is memoized per instance, so several configured providers never evict each
        other's configuration. A failed discovery raises DiscoveryFailureException and is not
        cached, so a later call retries it.
    """
    # getattr keeps this working for instances that bypassed __init__.
    cached = getattr(self, '_cached_oidc_config', None)
    if cached is not None:
      return cached

    if self.config.get('OIDC_SERVER'):
      oidc_config = self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
    else:
      oidc_config = {}

    self._cached_oidc_config = oidc_config
    return oidc_config

  def _load_oidc_config_via_discovery(self, is_debugging):
    """ Attempts to load the OIDC config via the OIDC discovery mechanism.

        If is_debugging is truthy, non-secure connections are allowed and TLS certificates are
        not verified. Raises a DiscoveryFailureException on failure.
    """
    oidc_server = self.config['OIDC_SERVER']
    if not oidc_server.startswith('https://') and not is_debugging:
      raise DiscoveryFailureException('OIDC server must be accessed over SSL')

    discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
    # Use the same truthiness test as the SSL check above: `is_debugging is False` disabled
    # certificate verification for falsy non-bool values such as None or 0.
    discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
    if discovery.status_code // 100 != 2:
      logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
      raise DiscoveryFailureException("Could not load OIDC discovery information")

    try:
      oidc_config = json.loads(discovery.text)
    except ValueError:
      logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
      raise DiscoveryFailureException("Could not parse OIDC discovery information")

    # Callers use .get() on the result, so anything other than a JSON object is unusable.
    if not isinstance(oidc_config, dict):
      logger.debug('OIDC discovery for url %s is not a JSON object', discovery_url)
      raise DiscoveryFailureException("Could not parse OIDC discovery information")

    return oidc_config

  def _decode_user_jwt(self, token):
    """ Decodes the given JWT under the given provider and returns it.

        Raises an InvalidTokenError exception on an invalid token or a PublicKeyLoadException if
        the public key could not be loaded for decoding.
    """
    # Find the key to use.
    headers = jwt.get_unverified_header(token)
    kid = headers.get('kid', None)
    if kid is None:
      raise InvalidTokenError('Missing `kid` header')

    def _decode_with(public_key):
      # Verifies signature, algorithm, audience (our client ID) and issuer.
      return decode(token, public_key, algorithms=ALLOWED_ALGORITHMS,
                    audience=self.client_id(), issuer=self._issuer,
                    leeway=JWT_CLOCK_SKEW_SECONDS, options=dict(require_nbf=False))

    try:
      return _decode_with(self._get_public_key(kid))
    except InvalidTokenError:
      # Public key may have expired. Try to retrieve an updated public key and use it to decode.
      # Note that any InvalidTokenError (not only a signature failure) triggers one refresh.
      return _decode_with(self._get_public_key(kid, force_refresh=True))

  def _get_public_key(self, kid, force_refresh=False):
    """ Retrieves the public key for this handler with the given kid.

        Raises a PublicKeyLoadException on failure.
    """
    # If force_refresh is true, we expire all the items in the cache by setting the time to
    # the current time + the expiration TTL.
    if force_refresh:
      self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

    # Retrieve the public key from the cache. If the cache does not contain the public key,
    # it will internally call _load_public_key to retrieve it and then save it.
    return self._public_key_cache[kid]

  def _load_public_key(self, kid):
    """ Loads the public key for this handler from the OIDC service.

        Raises PublicKeyLoadException on failure.
    """
    keys_url = self._oidc_config().get('jwks_uri')
    if not keys_url:
      raise PublicKeyLoadException('OIDC discovery information does not contain a `jwks_uri`')

    # Load the keys.
    try:
      keys = KEYS()
      keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
    except Exception as ex:
      logger.exception('Exception loading public key')
      raise PublicKeyLoadException(ex.message)

    # Find the matching key.
    keys_found = keys.by_kid(kid)
    if not keys_found:
      raise PublicKeyLoadException('Public key %s not found' % kid)

    rsa_keys = [key for key in keys_found if key.kty == 'RSA']
    if not rsa_keys:
      raise PublicKeyLoadException('No RSA form of public key %s found' % kid)

    matching_key = rsa_keys[0]
    matching_key.deserialize()

    # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
    # issues.
    return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())