Add proper and tested OIDC support on the server

Note that this will still not work on the client side; the followup CL for the client side is right after this one.
This commit is contained in:
Joseph Schorr 2017-01-23 17:53:34 -05:00
parent 19f7acf575
commit fda203e4d7
15 changed files with 756 additions and 180 deletions

View file

@ -3,108 +3,244 @@ import json
import logging
import urlparse
import jwt
from cachetools import lru_cache
from cachetools.ttl import TTLCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS
from util.oauth.base import OAuthService
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError
from util import get_app_url
logger = logging.getLogger(__name__)
def decode_user_jwt(token, oidc_provider):
try:
return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'],
audience=oidc_provider.client_id(),
issuer=oidc_provider.issuer)
except InvalidTokenError:
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'],
audience=oidc_provider.client_id(),
issuer=oidc_provider.issuer)
OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
ALLOWED_ALGORITHMS = ['RS256']
JWT_CLOCK_SKEW_SECONDS = 30
class OIDCConfig(OAuthService):
class DiscoveryFailureException(Exception):
""" Exception raised when OIDC discovery fails. """
pass
class PublicKeyLoadException(Exception):
""" Exception raised if loading the OIDC public key fails. """
pass
class OIDCLoginService(OAuthService):
""" Defines a generic service for all OpenID-connect compatible login services. """
def __init__(self, config, key_name):
super(OIDCConfig, self).__init__(config, key_name)
super(OIDCLoginService, self).__init__(config, key_name)
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
self._config = config
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key)
self._id = key_name[0:key_name.find('_')].lower()
self._http_client = config['HTTPCLIENT']
self._mailing = config.get('FEATURE_MAILING', False)
@lru_cache(maxsize=1)
def _oidc_config(self):
if self.config.get('OIDC_SERVER'):
return self._load_via_discovery(self._config.get('DEBUGGING', False))
else:
return {}
def service_id(self):
return self._id
def _load_via_discovery(self, is_debugging):
oidc_server = self.config['OIDC_SERVER']
if not oidc_server.startswith('https://') and not is_debugging:
raise Exception('OIDC server must be accessed over SSL')
def service_name(self):
return self.config.get('SERVICE_NAME', self.service_id())
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
discovery = self._http_client.get(discovery_url, timeout=5)
def get_icon(self):
return self.config.get('SERVICE_ICON', 'fa-user-circle')
if discovery.status_code / 100 != 2:
raise Exception("Could not load OIDC discovery information")
def get_login_scopes(self):
default_scopes = ['openid']
try:
return json.loads(discovery.text)
except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise Exception("Could not parse OIDC discovery information")
if self.user_endpoint() is not None:
default_scopes.append('profile')
if self._mailing:
default_scopes.append('email')
return self._oidc_config().get('scopes_supported', default_scopes)
def authorize_endpoint(self):
return self._oidc_config().get('authorization_endpoint', '') + '?'
return self._oidc_config().get('authorization_endpoint', '') + '?response_type=code&'
def token_endpoint(self):
return self._oidc_config().get('token_endpoint')
def user_endpoint(self):
return None
return self._oidc_config().get('userinfo_endpoint')
def validate_client_id_and_secret(self, http_client, app_config):
pass
# TODO: find a way to verify client secret too.
redirect_url = '%s/oauth2/%s/callback' % (get_app_url(app_config), self.service_id())
scopes_string = ' '.join(self.get_login_scopes())
authorize_url = '%sclient_id=%s&redirect_uri=%s&scope=%s' % (self.authorize_endpoint(),
self.client_id(),
redirect_url,
scopes_string)
check_auth_url = http_client.get(authorize_url)
if check_auth_url.status_code // 100 != 2:
raise Exception('Got non-200 status code for authorization endpoint')
def requires_form_encoding(self):
return True
def get_public_config(self):
return {
'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
'OIDC': True,
}
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
# Exchange the code for the access token and id_token
try:
json_data = self.exchange_code(app_config, http_client, code,
redirect_suffix=redirect_suffix,
form_encode=self.requires_form_encoding())
except OAuthExchangeCodeException as oce:
raise OAuthLoginException(oce.message)
# Make sure we received both.
access_token = json_data.get('access_token', None)
if access_token is None:
logger.debug('Missing access_token in response: %s', json_data)
raise OAuthLoginException('Missing `access_token` in OIDC response')
id_token = json_data.get('id_token', None)
if id_token is None:
logger.debug('Missing id_token in response: %s', json_data)
raise OAuthLoginException('Missing `id_token` in OIDC response')
# Decode the id_token.
try:
decoded_id_token = self._decode_user_jwt(id_token)
except InvalidTokenError as ite:
logger.exception('Got invalid token error on OIDC decode: %s', ite.message)
raise OAuthLoginException('Could not decode OIDC token')
except PublicKeyLoadException as pke:
logger.exception('Could not load public key during OIDC decode: %s', pke.message)
raise OAuthLoginException('Could find public OIDC key')
# Retrieve the user information.
try:
user_info = self.get_user_info(http_client, access_token)
except OAuthGetUserInfoException as oge:
raise OAuthLoginException(oge.message)
# Verify subs.
if user_info['sub'] != decoded_id_token['sub']:
raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')
# Check if we have a verified email address.
email_address = user_info.get('email') if user_info.get('email_verified') else None
if self._mailing:
if email_address is None:
raise OAuthLoginException('A verified email address is required to login with this service')
# Check for a preferred username.
lusername = user_info.get('preferred_username') or user_info.get('sub')
return decoded_id_token['sub'], lusername, email_address
@property
def issuer(self):
def _issuer(self):
return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER'])
def get_public_key(self, force_refresh=False):
""" Retrieves the public key for this handler. """
@lru_cache(maxsize=1)
def _oidc_config(self):
if self.config.get('OIDC_SERVER'):
return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
else:
return {}
def _load_oidc_config_via_discovery(self, is_debugging):
""" Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True,
non-secure connections are alllowed. Raises an DiscoveryFailureException on failure.
"""
oidc_server = self.config['OIDC_SERVER']
if not oidc_server.startswith('https://') and not is_debugging:
raise DiscoveryFailureException('OIDC server must be accessed over SSL')
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
if discovery.status_code // 100 != 2:
logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
raise DiscoveryFailureException("Could not load OIDC discovery information")
try:
return json.loads(discovery.text)
except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise DiscoveryFailureException("Could not parse OIDC discovery information")
def _decode_user_jwt(self, token):
""" Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
exception on an invalid token or a PublicKeyLoadException if the public key could not be
loaded for decoding.
"""
# Find the key to use.
headers = jwt.get_unverified_header(token)
kid = headers.get('kid', None)
if kid is None:
raise InvalidTokenError('Missing `kid` header')
try:
return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
except InvalidTokenError:
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
return decode(token, self._get_public_key(kid, force_refresh=True),
algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
def _get_public_key(self, kid, force_refresh=False):
""" Retrieves the public key for this handler with the given kid. Raises a
PublicKeyLoadException on failure. """
# If force_refresh is true, we expire all the items in the cache by setting the time to
# the current time + the expiration TTL.
if force_refresh:
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
# Retrieve the public key from the cache. If the cache does not contain the public key,
# it will internally call _get_public_key to retrieve it and then save it. The None is
# a random key chose to be stored in the cache, and could be anything.
return self._public_key_cache[None]
# it will internally call _load_public_key to retrieve it and then save it.
return self._public_key_cache[kid]
def _get_public_key(self, _):
""" Retrieves the public key for this handler. """
def _load_public_key(self, kid):
""" Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
on failure.
"""
keys_url = self._oidc_config()['jwks_uri']
keys = KEYS()
keys.load_from_url(keys_url)
# Load the keys.
try:
keys = KEYS()
keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
except Exception as ex:
logger.exception('Exception loading public key')
raise PublicKeyLoadException(ex.message)
if not list(keys):
raise Exception('No keys provided by OIDC provider')
# Find the matching key.
keys_found = keys.by_kid(kid)
if len(keys_found) == 0:
raise PublicKeyLoadException('Public key %s not found' % kid)
rsa_key = list(keys)[0]
rsa_key.deserialize()
rsa_keys = [key for key in keys_found if key.kty == 'RSA']
if len(rsa_keys) == 0:
raise PublicKeyLoadException('No RSA form of public key %s not found' % kid)
matching_key = rsa_keys[0]
matching_key.deserialize()
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
# issues.
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())