This repository has been archived on 2020-03-24. You can view files and clone it, but cannot push or open issues or pull requests.
quay/oauth/oidc.py
Joseph Schorr a9791ea419 Have external login always make an API request to get the authorization URL
This makes the OIDC lookup lazy, ensuring that the rest of the registry and app continues working even if one OIDC provider goes down.
2017-01-23 19:06:19 -05:00

237 lines
8.9 KiB
Python

import time
import json
import logging
import urlparse
import jwt
from cachetools import lru_cache
from cachetools.ttl import TTLCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError
logger = logging.getLogger(__name__)

# Path of the OpenID Connect discovery document, joined onto the configured OIDC server URL.
OIDC_WELLKNOWN = '.well-known/openid-configuration'

# Lifetime, in seconds, of entries in each service's public key cache (one hour).
PUBLIC_KEY_CACHE_TTL = 60 * 60

# Signing algorithms accepted when decoding an id_token.
ALLOWED_ALGORITHMS = ['RS256']

# Passed as `leeway` when decoding an id_token, to tolerate clock skew with the provider.
JWT_CLOCK_SKEW_SECONDS = 30
class DiscoveryFailureException(Exception):
  """ Raised when OIDC discovery fails, e.g. the provider's discovery document cannot be
      fetched or parsed. """
class PublicKeyLoadException(Exception):
  """ Raised when the provider's public key needed to verify an OIDC token cannot be loaded. """
class OIDCLoginService(OAuthService):
  """ Defines a generic service for all OpenID-connect compatible login services.

  The provider's OIDC discovery document is not fetched in the constructor. It is loaded on first
  use and then cached on the instance, so an unreachable provider only affects the operations
  that actually need it.
  """

  # Maximum number of public keys (by `kid`) held in each service's key cache. More than one is
  # kept so that tokens signed with different keys (e.g. around a key rotation) do not keep
  # evicting each other and forcing a JWKS refetch on every login.
  _PUBLIC_KEY_CACHE_MAXSIZE = 8

  def __init__(self, config, key_name):
    # Set before the base constructor runs, in case it reaches methods that read it.
    # None means "discovery document not loaded yet"; see _oidc_config.
    self._oidc_config_cache = None

    super(OIDCLoginService, self).__init__(config, key_name)

    # Public keys by `kid`, expiring after PUBLIC_KEY_CACHE_TTL seconds. On a cache miss the
    # cache calls _load_public_key(kid) to fetch the key from the provider.
    self._public_key_cache = TTLCache(self._PUBLIC_KEY_CACHE_MAXSIZE, PUBLIC_KEY_CACHE_TTL,
                                      missing=self._load_public_key)

    # The service id is the lowercased part of the config key before the first underscore.
    # split() is used, rather than slicing up to find('_'), so that a key without an underscore
    # yields the whole key instead of silently losing its last character.
    self._id = key_name.split('_', 1)[0].lower()
    self._http_client = config['HTTPCLIENT']
    self._mailing = config.get('FEATURE_MAILING', False)

  def service_id(self):
    """ Returns the lowercased service identifier derived from the config key name. """
    return self._id

  def service_name(self):
    """ Returns SERVICE_NAME from the config, falling back to the service id. """
    return self.config.get('SERVICE_NAME', self.service_id())

  def get_icon(self):
    """ Returns SERVICE_ICON from the config, defaulting to 'fa-user-circle'. """
    return self.config.get('SERVICE_ICON', 'fa-user-circle')

  def get_login_scopes(self):
    """ Returns the scopes to request at login.

    If the discovery document advertises `scopes_supported`, that list is returned as-is.
    Otherwise the default is 'openid', plus 'profile' when a userinfo endpoint is advertised,
    plus 'email' when mailing is enabled.
    """
    default_scopes = ['openid']

    if self.user_endpoint() is not None:
      default_scopes.append('profile')

    if self._mailing:
      default_scopes.append('email')

    return self._oidc_config().get('scopes_supported', default_scopes)

  def authorize_endpoint(self):
    """ Returns the provider's authorization endpoint with `response_type=code` added. The result
        ends with '&' so further query parameters can be appended directly. """
    endpoint = self._oidc_config().get('authorization_endpoint', '')

    # Some providers advertise an authorization endpoint that already carries a query string;
    # in that case the new parameter must be joined with '&' instead of starting a second '?'.
    if '?' not in endpoint:
      separator = '?'
    elif endpoint.endswith(('?', '&')):
      separator = ''
    else:
      separator = '&'

    return endpoint + separator + 'response_type=code&'

  def token_endpoint(self):
    """ Returns the provider's token endpoint from discovery, or None if not advertised. """
    return self._oidc_config().get('token_endpoint')

  def user_endpoint(self):
    """ Returns the provider's userinfo endpoint from discovery, or None if not advertised. """
    return self._oidc_config().get('userinfo_endpoint')

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Sanity-checks the configuration by requesting the authorization URL and requiring a 2xx
        response. The client secret is not verified. """
    # TODO: find a way to verify client secret too.
    # NOTE(review): get_auth_url is called with no arguments here; confirm this matches the
    # base-class signature.
    check_auth_url = http_client.get(self.get_auth_url())
    if check_auth_url.status_code // 100 != 2:
      raise Exception('Got non-200 status code for authorization endpoint')

  def requires_form_encoding(self):
    """ The authorization code exchange for OIDC is sent form-encoded. """
    return True

  def get_public_config(self):
    """ Returns the non-secret configuration for this service: its client id and an OIDC flag. """
    return {
      'CLIENT_ID': self.client_id(),
      'OIDC': True,
    }

  @staticmethod
  def _claim_is_true(value):
    """ Returns whether a boolean-valued claim (such as `email_verified`) is true.

    Some providers serialize boolean claims as the strings 'true'/'false'. Any non-empty string
    is truthy in Python, so strings are compared explicitly instead of being passed to bool().
    """
    if isinstance(value, (str, type(u''))):
      return value.strip().lower() == 'true'
    return bool(value)

  def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
    """ Exchanges an OAuth authorization code for the identity of the logging-in user.

    Returns a tuple of (subject, username, email_address):
      - subject is the `sub` claim of the verified id_token;
      - username is the user info's `preferred_username`, falling back to its `sub`;
      - email_address is the user info's `email` if `email_verified` is true, else None.

    Raises OAuthLoginException when the code exchange, token verification, user info lookup or
    any of the consistency checks below fail.
    """
    # Exchange the code for the access token and id_token.
    try:
      json_data = self.exchange_code(app_config, http_client, code,
                                     redirect_suffix=redirect_suffix,
                                     form_encode=self.requires_form_encoding())
    except OAuthExchangeCodeException as oce:
      raise OAuthLoginException(oce.message)
    except DiscoveryFailureException:
      # The exchange may need the discovery document (e.g. for the token endpoint); a provider
      # outage should surface as a login error rather than an unhandled exception.
      logger.exception('OIDC discovery failed during code exchange')
      raise OAuthLoginException('Could not load OIDC configuration')

    # Make sure we received both.
    access_token = json_data.get('access_token', None)
    if access_token is None:
      logger.debug('Missing access_token in response: %s', json_data)
      raise OAuthLoginException('Missing `access_token` in OIDC response')

    id_token = json_data.get('id_token', None)
    if id_token is None:
      logger.debug('Missing id_token in response: %s', json_data)
      raise OAuthLoginException('Missing `id_token` in OIDC response')

    # Decode and verify the id_token.
    try:
      decoded_id_token = self._decode_user_jwt(id_token)
    except InvalidTokenError as ite:
      logger.exception('Got invalid token error on OIDC decode: %s', ite.message)
      raise OAuthLoginException('Could not decode OIDC token')
    except PublicKeyLoadException as pke:
      logger.exception('Could not load public key during OIDC decode: %s', pke.message)
      raise OAuthLoginException('Could not find public OIDC key')

    # Retrieve the user information.
    try:
      user_info = self.get_user_info(http_client, access_token)
    except OAuthGetUserInfoException as oge:
      raise OAuthLoginException(oge.message)

    # Verify that the user info belongs to the same subject as the verified id_token. Missing
    # claims are reported as login errors instead of escaping as KeyErrors.
    subject = decoded_id_token.get('sub')
    if not subject:
      raise OAuthLoginException('Missing `sub` in OIDC id_token')

    if user_info.get('sub') != subject:
      raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')

    # Only trust the email address if the provider marks it as verified.
    email_verified = self._claim_is_true(user_info.get('email_verified'))
    email_address = user_info.get('email') if email_verified else None
    if self._mailing and email_address is None:
      raise OAuthLoginException('A verified email address is required to login with this service')

    # Prefer the provider's suggested username, falling back to the subject.
    lusername = user_info.get('preferred_username') or user_info.get('sub')
    return subject, lusername, email_address

  @property
  def _issuer(self):
    """ The expected `iss` claim: OIDC_ISSUER if configured, otherwise OIDC_SERVER. """
    # OIDC_SERVER is only looked up as a fallback, so configuring OIDC_ISSUER alone does not
    # raise a KeyError.
    if 'OIDC_ISSUER' in self.config:
      return self.config['OIDC_ISSUER']
    return self.config['OIDC_SERVER']

  def _oidc_config(self):
    """ Returns the provider's OIDC configuration (discovery document) as a dict.

    The document is fetched on first use and cached on this instance. Failures are not cached,
    so the next call retries. Returns an empty dict when no OIDC_SERVER is configured.
    Raises DiscoveryFailureException if discovery fails.
    """
    # A per-instance cache is used instead of a method-level lru_cache. That cache would be
    # shared by every provider instance (evicting each other with maxsize=1) and would keep
    # instances alive.
    if self._oidc_config_cache is None:
      if self.config.get('OIDC_SERVER'):
        self._oidc_config_cache = self._load_oidc_config_via_discovery(
          self.config.get('DEBUGGING', False))
      else:
        self._oidc_config_cache = {}

    return self._oidc_config_cache

  def _load_oidc_config_via_discovery(self, is_debugging):
    """ Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is
        True, non-secure connections are allowed. Raises a DiscoveryFailureException on failure.
    """
    oidc_server = self.config['OIDC_SERVER']
    if not oidc_server.startswith('https://') and not is_debugging:
      raise DiscoveryFailureException('OIDC server must be accessed over SSL')

    # urljoin replaces the last path segment of a base URL that does not end in '/'. That would
    # drop, e.g., a realm or tenant path component. Treat the server URL as a directory so that
    # the well-known path is appended beneath it.
    if not oidc_server.endswith('/'):
      oidc_server += '/'

    discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)

    try:
      discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
    except Exception:
      # Errors raised by the HTTP client itself (e.g. the provider being unreachable) are
      # reported as discovery failures, matching this method's documented contract.
      logger.exception('Could not reach OIDC discovery url: %s', discovery_url)
      raise DiscoveryFailureException("Could not load OIDC discovery information")

    if discovery.status_code // 100 != 2:
      logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
      raise DiscoveryFailureException("Could not load OIDC discovery information")

    try:
      oidc_config = json.loads(discovery.text)
    except ValueError:
      logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
      raise DiscoveryFailureException("Could not parse OIDC discovery information")

    # Callers read the configuration with dict.get(), so any other JSON shape is rejected.
    if not isinstance(oidc_config, dict):
      logger.debug('OIDC discovery for url %s is not a JSON object', discovery_url)
      raise DiscoveryFailureException("Could not parse OIDC discovery information")

    return oidc_config

  def _decode_user_jwt(self, token):
    """ Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
        exception on an invalid token or a PublicKeyLoadException if the public key could not be
        loaded for decoding.
    """
    # Find the key to use.
    headers = jwt.get_unverified_header(token)
    kid = headers.get('kid', None)
    if kid is None:
      raise InvalidTokenError('Missing `kid` header')

    def decode_with_key(public_key):
      # Verifies signature, audience (our client id) and issuer, with clock-skew leeway.
      return decode(token, public_key, algorithms=ALLOWED_ALGORITHMS,
                    audience=self.client_id(),
                    issuer=self._issuer,
                    leeway=JWT_CLOCK_SKEW_SECONDS,
                    options=dict(require_nbf=False))

    try:
      return decode_with_key(self._get_public_key(kid))
    except InvalidTokenError:
      # Public key may have expired. Try to retrieve an updated public key and use it to decode.
      return decode_with_key(self._get_public_key(kid, force_refresh=True))

  def _get_public_key(self, kid, force_refresh=False):
    """ Retrieves the public key for this handler with the given kid. Raises a
        PublicKeyLoadException on failure. """
    # If force_refresh is true, we expire all the items in the cache by setting the time to
    # the current time + the expiration TTL.
    if force_refresh:
      self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

    # Retrieve the public key from the cache. If the cache does not contain the public key,
    # it will internally call _load_public_key to retrieve it and then save it.
    return self._public_key_cache[kid]

  def _load_public_key(self, kid):
    """ Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
        on failure.
    """
    # Locate the provider's key set. Both failure modes are converted to PublicKeyLoadException,
    # which is what the token-decoding callers handle.
    try:
      keys_url = self._oidc_config()['jwks_uri']
    except KeyError:
      raise PublicKeyLoadException('OIDC discovery information is missing `jwks_uri`')
    except DiscoveryFailureException as dfe:
      raise PublicKeyLoadException(dfe.message)

    # Load the keys.
    try:
      keys = KEYS()
      keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
    except Exception as ex:
      logger.exception('Exception loading public key')
      raise PublicKeyLoadException(ex.message)

    # Find the matching key.
    keys_found = keys.by_kid(kid)
    if len(keys_found) == 0:
      raise PublicKeyLoadException('Public key %s not found' % kid)

    rsa_keys = [key for key in keys_found if key.kty == 'RSA']
    if len(rsa_keys) == 0:
      raise PublicKeyLoadException('No RSA form of public key %s found' % kid)

    matching_key = rsa_keys[0]
    matching_key.deserialize()

    # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
    # issues.
    return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())