diff --git a/endpoints/api/__init__.py b/endpoints/api/__init__.py index d9c3d0fb0..6a9369d8d 100644 --- a/endpoints/api/__init__.py +++ b/endpoints/api/__init__.py @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) api_bp = Blueprint('api', __name__) api = Api() api.init_app(api_bp) -api.decorators = [csrf_protect, +api.decorators = [csrf_protect(), crossdomain(origin='*', headers=['Authorization', 'Content-Type']), process_oauth, time_decorator(api_bp.name, metric_queue)] diff --git a/endpoints/api/user.py b/endpoints/api/user.py index 5457132e8..114ac7277 100644 --- a/endpoints/api/user.py +++ b/endpoints/api/user.py @@ -26,6 +26,7 @@ from endpoints.api import (ApiResource, nickname, resource, validate_json_reques from endpoints.exception import NotFound, InvalidToken from endpoints.api.subscribe import subscribe from endpoints.common import common_login +from endpoints.csrf import generate_csrf_token, OAUTH_CSRF_TOKEN_NAME from endpoints.decorators import anon_allowed from util.useremails import (send_confirmation_email, send_recovery_email, send_change_email, send_password_changed, send_org_recovery_email) @@ -673,6 +674,15 @@ class Signout(ApiResource): return {'success': True} +@resource('/v1/externaltoken') +@internal_only +class GenerateExternalToken(ApiResource): + """ Resource for generating a token for external login. """ + @nickname('generateExternalLoginToken') + def post(self): + """ Generates a CSRF token explicitly for OIDC/OAuth-associated login. """ + return {'token': generate_csrf_token(OAUTH_CSRF_TOKEN_NAME)} + @resource('/v1/detachexternal/') @show_if(features.DIRECT_LOGIN) diff --git a/endpoints/csrf.py b/endpoints/csrf.py index 39a0d636b..b2dbfcff1 100644 --- a/endpoints/csrf.py +++ b/endpoints/csrf.py @@ -1,9 +1,10 @@ import logging import os import base64 +import hmac -from flask import session, request from functools import wraps +from flask import session, request from app import app from auth.auth_context import get_validated_oauth_token @@ -12,31 +13,47 @@ from util.http import abort logger = logging.getLogger(__name__) +OAUTH_CSRF_TOKEN_NAME = '_oauth_csrf_token' +_QUAY_CSRF_TOKEN_NAME = '_csrf_token' -def generate_csrf_token(): - if '_csrf_token' not in session: - session['_csrf_token'] = base64.b64encode(os.urandom(48)) +def generate_csrf_token(session_token_name=_QUAY_CSRF_TOKEN_NAME): + """ If not present in the session, generates a new CSRF token with the given name + and places it into the session. Returns the generated token. + """ + if session_token_name not in session: + session[session_token_name] = base64.b64encode(os.urandom(48)) - return session['_csrf_token'] + return session[session_token_name] -def verify_csrf(): - token = session.get('_csrf_token', None) - found_token = request.values.get('_csrf_token', None) - if not token or token != found_token: - msg = 'CSRF Failure. Session token was %s and request token was %s' - logger.error(msg, token, found_token) +def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME, + request_token_name=_QUAY_CSRF_TOKEN_NAME): + """ Verifies that the CSRF token with the given name is found in the session and + that the matching token is found in the request args or values. + """ + token = str(session.get(session_token_name, '')) + found_token = str(request.values.get(request_token_name, '')) + + if not token or not found_token or not hmac.compare_digest(token, found_token): + msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s' + logger.error(msg, session_token_name, token, request_token_name, found_token) abort(403, message='CSRF token was invalid or missing.') -def csrf_protect(func): - @wraps(func) - def wrapper(*args, **kwargs): - oauth_token = get_validated_oauth_token() - if oauth_token is None and request.method != "GET" and request.method != "HEAD": - verify_csrf() - return func(*args, **kwargs) - return wrapper +def csrf_protect(session_token_name=_QUAY_CSRF_TOKEN_NAME, + request_token_name=_QUAY_CSRF_TOKEN_NAME, + all_methods=False): + def inner(func): + @wraps(func) + def wrapper(*args, **kwargs): + oauth_token = get_validated_oauth_token() + if oauth_token is None: + if all_methods or (request.method != "GET" and request.method != "HEAD"): + verify_csrf(session_token_name, request_token_name) + + return func(*args, **kwargs) + return wrapper + return inner app.jinja_env.globals['csrf_token'] = generate_csrf_token diff --git a/endpoints/oauthlogin.py b/endpoints/oauthlogin.py index a750f5519..17cb6da20 100644 --- a/endpoints/oauthlogin.py +++ b/endpoints/oauthlogin.py @@ -12,6 +12,7 @@ from auth.process import require_session_login from data import model from endpoints.common import common_login, route_show_if from endpoints.web import index +from endpoints.csrf import csrf_protect, OAUTH_CSRF_TOKEN_NAME from util.security.jwtutil import decode, InvalidTokenError from util.validation import generate_valid_usernames @@ -19,6 +20,7 @@ logger = logging.getLogger(__name__) client = app.config['HTTPCLIENT'] oauthlogin = Blueprint('oauthlogin', __name__) +oauthlogin_csrf_protect = csrf_protect(OAUTH_CSRF_TOKEN_NAME, 'state', all_methods=True) def render_ologin_error(service_name, error_message=None, register_redirect=False): user_creation = bool(features.USER_CREATION and features.DIRECT_LOGIN) @@ -30,7 +32,11 @@ def render_ologin_error(service_name, error_message=None, register_redirect=Fals 'user_creation': user_creation, 'register_redirect': register_redirect, } - return index('', error_info=error_info) + + resp = index('', error_info=error_info) + resp.status_code = 400 + return resp + def get_user(service, token): token_param = { @@ -44,7 +50,7 @@ def get_user(service, token): return got_user.json() -def conduct_oauth_login(service, user_id, username, email, metadata={}): +def conduct_oauth_login(service, user_id, username, email, metadata=None): service_name = service.service_name() to_login = model.user.verify_federated_login(service_name.lower(), user_id) if not to_login: @@ -66,17 +72,12 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}): prompts = model.user.get_default_user_prompts(features) to_login = model.user.create_federated_user(new_username, email, service_name.lower(), user_id, set_password_notification=True, - metadata=metadata, + metadata=metadata or {}, prompts=prompts) # Success, tell analytics analytics.track(to_login.username, 'register', {'service': service_name.lower()}) - state = request.args.get('state', None) - if state: - logger.debug('Aliasing with state: %s', state) - analytics.alias(to_login.username, state) - except model.InvalidEmailAddressException: message = "The e-mail address %s is already associated " % (email, ) message = message + "with an existing %s account." % (app.config['REGISTRY_TITLE_SHORT'], ) @@ -96,6 +97,7 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}): return render_ologin_error(service_name) + def get_email_username(user_data): username = user_data['email'] at = username.find('@') @@ -107,6 +109,7 @@ def get_email_username(user_data): @oauthlogin.route('/google/callback', methods=['GET']) @route_show_if(features.GOOGLE_LOGIN) +@oauthlogin_csrf_protect def google_oauth_callback(): error = request.args.get('error', None) if error: @@ -114,6 +117,9 @@ def google_oauth_callback(): code = request.args.get('code') token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True) + if token is None: + return render_ologin_error('Google') + user_data = get_user(google_login, token) if not user_data or not user_data.get('id', None) or not user_data.get('email', None): return render_ologin_error('Google') @@ -136,6 +142,7 @@ def google_oauth_callback(): @oauthlogin.route('/github/callback', methods=['GET']) @route_show_if(features.GITHUB_LOGIN) +@oauthlogin_csrf_protect def github_oauth_callback(): error = request.args.get('error', None) if error: @@ -144,10 +151,12 @@ def github_oauth_callback(): # Exchange the OAuth code. code = request.args.get('code') token = github_login.exchange_code_for_token(app.config, client, code) + if token is None: + return render_ologin_error('GitHub') # Retrieve the user's information. user_data = get_user(github_login, token) - if not user_data or not 'login' in user_data: + if not user_data or 'login' not in user_data: return render_ologin_error('GitHub') username = user_data['login'] @@ -167,7 +176,8 @@ def github_oauth_callback(): headers={'Accept': 'application/vnd.github.moondragon+json'}) organizations = set([org.get('login').lower() for org in get_orgs.json()]) - if not (organizations & set(github_login.allowed_organizations())): + matching_organizations = organizations & set(github_login.allowed_organizations()) + if not matching_organizations: err = """You are not a member of an allowed GitHub organization. Please contact your system administrator if you believe this is in error.""" return render_ologin_error('GitHub', err) @@ -175,6 +185,8 @@ def github_oauth_callback(): # Find the e-mail address for the user: we will accept any email, but we prefer the primary get_email = client.get(github_login.email_endpoint(), params=token_param, headers=v3_media_type) + if get_email.status_code / 100 != 2: + return render_ologin_error('GitHub') found_email = None for user_email in get_email.json(): @@ -199,10 +211,13 @@ def github_oauth_callback(): @oauthlogin.route('/google/callback/attach', methods=['GET']) @route_show_if(features.GOOGLE_LOGIN) @require_session_login +@oauthlogin_csrf_protect def google_oauth_attach(): code = request.args.get('code') token = google_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach', form_encode=True) + if token is None: + return render_ologin_error('Google') user_data = get_user(google_login, token) if not user_data or not user_data.get('id', None): @@ -236,9 +251,13 @@ def google_oauth_attach(): @oauthlogin.route('/github/callback/attach', methods=['GET']) @route_show_if(features.GITHUB_LOGIN) @require_session_login +@oauthlogin_csrf_protect def github_oauth_attach(): code = request.args.get('code') token = github_login.exchange_code_for_token(app.config, client, code) + if token is None: + return render_ologin_error('GitHub') + user_data = get_user(github_login, token) if not user_data: return render_ologin_error('GitHub') @@ -276,6 +295,7 @@ def decode_user_jwt(token, oidc_provider): @oauthlogin.route('/dex/callback', methods=['GET', 'POST']) @route_show_if(features.DEX_LOGIN) +@oauthlogin_csrf_protect def dex_oauth_callback(): error = request.values.get('error', None) if error: @@ -287,6 +307,8 @@ def dex_oauth_callback(): token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True, form_encode=True) + if token is None: + return render_ologin_error(dex_login.public_title) try: payload = decode_user_jwt(token, dex_login) @@ -318,11 +340,12 @@ def dex_oauth_callback(): @oauthlogin.route('/dex/callback/attach', methods=['GET', 'POST']) @route_show_if(features.DEX_LOGIN) @require_session_login +@oauthlogin_csrf_protect def dex_oauth_attach(): code = request.args.get('code') token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach', client_auth=True, form_encode=True) - if not token: + if token is None: return render_ologin_error(dex_login.public_title) try: @@ -346,4 +369,3 @@ def dex_oauth_attach(): return render_ologin_error(dex_login.public_title, err) return redirect(url_for('web.user_view', path=user_obj.username, tab='external')) - diff --git a/endpoints/web.py b/endpoints/web.py index 03fec46f6..f3a6f7ce7 100644 --- a/endpoints/web.py +++ b/endpoints/web.py @@ -494,7 +494,7 @@ def oauth_local_handler(): @web.route('/oauth/denyapp', methods=['POST']) -@csrf_protect +@csrf_protect() def deny_application(): if not current_user.is_authenticated: abort(401) diff --git a/static/js/app.js b/static/js/app.js index bdffb9451..ba49e6181 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -215,7 +215,8 @@ quayApp.config(['$routeProvider', '$locationProvider', 'pages', function($routeP // 404/403 .route('/:catchall', 'error-view') .route('/:catch/:all', 'error-view') - .route('/:catch/:all/:things', 'error-view'); + .route('/:catch/:all/:things', 'error-view') + .route('/:catch/:all/:things/:here', 'error-view'); }]); // Configure compile provider to add additional URL prefixes to the sanitization list. We use diff --git a/static/js/directives/ui/external-login-button.js b/static/js/directives/ui/external-login-button.js index 509b2b489..f8ec07e53 100644 --- a/static/js/directives/ui/external-login-button.js +++ b/static/js/directives/ui/external-login-button.js @@ -21,19 +21,21 @@ angular.module('quay').directive('externalLoginButton', function () { $scope.startSignin = function() { $scope.signInStarted({'service': $scope.provider}); + ApiService.generateExternalLoginToken().then(function(data) { + var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login'); + url = url + '&state=' + encodeURIComponent(data['token']); - var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login'); + // Save the redirect URL in a cookie so that we can redirect back after the service returns to us. + var redirectURL = $scope.redirectUrl || window.location.toString(); + CookieService.putPermanent('quay.redirectAfterLoad', redirectURL); - // Save the redirect URL in a cookie so that we can redirect back after the service returns to us. - var redirectURL = $scope.redirectUrl || window.location.toString(); - CookieService.putPermanent('quay.redirectAfterLoad', redirectURL); - - // Needed to ensure that UI work done by the started callback is finished before the location - // changes. - $scope.signingIn = true; - $timeout(function() { - document.location = url; - }, 250); + // Needed to ensure that UI work done by the started callback is finished before the location + // changes. + $scope.signingIn = true; + $timeout(function() { + document.location = url; + }, 250); + }, ApiService.errorDisplay('Could not perform sign in')); }; } }; diff --git a/static/js/services/external-login-service.js b/static/js/services/external-login-service.js index 91197e7b4..7d6061df6 100644 --- a/static/js/services/external-login-service.js +++ b/static/js/services/external-login-service.js @@ -9,14 +9,6 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features' var serviceInfo = externalLoginService.getProvider(service); if (!serviceInfo) { return ''; } - var stateClause = ''; - - if (Config.MIXPANEL_KEY && window.mixpanel) { - if (mixpanel.get_distinct_id !== undefined) { - stateClause = "&state=" + encodeURIComponent(mixpanel.get_distinct_id()); - } - } - var loginUrl = KeyService.getConfiguration(serviceInfo.key, 'AUTHORIZE_ENDPOINT'); var clientId = KeyService.getConfiguration(serviceInfo.key, 'CLIENT_ID'); @@ -28,8 +20,7 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features' } var url = loginUrl + 'client_id=' + clientId + '&scope=' + scope + '&redirect_uri=' + - redirectUri + stateClause; - + redirectUri; return url; }; diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 1bfb8f9dc..f89003972 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -124,9 +124,11 @@ class ApiTestCase(unittest.TestCase): query[CSRF_TOKEN_KEY] = CSRF_TOKEN return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:])) - def url_for(self, resource_name, params={}): + def url_for(self, resource_name, params=None, skip_csrf=False): + params = params or {} url = api.url_for(resource_name, **params) - url = ApiTestCase._add_csrf(url) + if not skip_csrf: + url = ApiTestCase._add_csrf(url) return url def setUp(self): @@ -211,8 +213,8 @@ class ApiTestCase(unittest.TestCase): return parsed def putJsonResponse(self, resource_name, params={}, data={}, - expected_code=200): - rv = self.app.put(self.url_for(resource_name, params), + expected_code=200, skip_csrf=False): + rv = self.app.put(self.url_for(resource_name, params, skip_csrf), data=py_json.dumps(data), headers={"Content-Type": "application/json"}) @@ -246,15 +248,35 @@ class TestCSRFFailure(ApiTestCase): self.login(READ_ACCESS_USER) # Make sure a simple post call succeeds. - self.putJsonResponse(User, - data=dict(password='newpasswordiscool')) + self.putJsonResponse(User, data=dict(password='newpasswordiscool')) # Change the session's CSRF token. self.setCsrfToken('someinvalidtoken') # Verify that the call now fails. - self.putJsonResponse(User, - data=dict(password='newpasswordiscool'), + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403) + + def test_csrf_failure_empty_token(self): + self.login(READ_ACCESS_USER) + + # Change the session's CSRF token to be empty. + self.setCsrfToken('') + + # Verify that the call now fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403) + + def test_csrf_failure_missing_token(self): + self.login(READ_ACCESS_USER) + + # Make sure a simple post call without a token at all fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True, + expected_code=403) + + # Change the session's CSRF token to be empty. + self.setCsrfToken('') + + # Verify that the call still fails. + self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True, expected_code=403) diff --git a/test/test_endpoints.py b/test/test_endpoints.py index b29db6af5..e520b3b37 100644 --- a/test/test_endpoints.py +++ b/test/test_endpoints.py @@ -8,6 +8,7 @@ import base64 from urllib import urlencode from urlparse import urlparse, urlunparse, parse_qs from datetime import datetime, timedelta +from httmock import urlmatch, HTTMock import jwt @@ -22,13 +23,25 @@ from endpoints import keyserver from endpoints.api import api, api_bp from endpoints.api.user import Signin from endpoints.keyserver import jwk_with_kid +from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME from endpoints.web import web as web_bp +from endpoints.oauthlogin import oauthlogin as oauthlogin_bp from initdb import setup_database_for_testing, finished_database_for_testing from test.helpers import assert_action_logged +try: + app.register_blueprint(oauthlogin_bp, url_prefix='/oauth') +except ValueError: + # This blueprint was already registered + pass try: app.register_blueprint(web_bp, url_prefix='') +except ValueError: + # This blueprint was already registered + pass + +try: app.register_blueprint(keyserver.key_server, url_prefix='') except ValueError: # This blueprint was already registered @@ -69,6 +82,7 @@ class EndpointTestCase(unittest.TestCase): def setCsrfToken(self, token): with self.app.session_transaction() as sess: sess[CSRF_TOKEN_KEY] = token + sess[OAUTH_CSRF_TOKEN_NAME] = 'someoauthtoken' def getResponse(self, resource_name, expected_code=200, **kwargs): rv = self.app.get(url_for(resource_name, **kwargs)) @@ -108,6 +122,140 @@ class EndpointTestCase(unittest.TestCase): self.assertEquals(rv.status_code, 200) +class OAuthLoginTestCase(EndpointTestCase): + def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident, + new_username): + # Test callback. + created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username) + + # Delete the created user. + model.user.delete_user(created, []) + + # Test attach. + self.login('devtable', 'password') + self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable') + + def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username): + # No CSRF. + self.getResponse('oauthlogin.' + endpoint_name, expected_code=403) + + # Invalid CSRF. + self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403) + + # Valid CSRF, invalid code. + self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', + code='invalidcode', expected_code=400) + + # Valid CSRF, valid code. + self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', + code='somecode', expected_code=302) + + # Ensure the user was added/modified. + found_user = model.user.get_user(username) + self.assertIsNotNone(found_user) + + federated_login = model.user.lookup_federated_login(found_user, service_name) + self.assertIsNotNone(federated_login) + self.assertEquals(federated_login.service_ident, service_ident) + return found_user + + def test_google_oauth(self): + @urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token') + def account_handler(_, request): + if request.body.find("code=somecode") > 0: + content = {'access_token': 'someaccesstoken'} + return py_json.dumps(content) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo') + def user_handler(_, __): + content = { + 'id': 'someid', + 'email': 'someemail@example.com', + 'verified_email': True, + } + return py_json.dumps(content) + + with HTTMock(account_handler, user_handler): + self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google', + 'someid', 'someemail') + + def test_github_oauth(self): + @urlmatch(netloc=r'github.com', path='/login/oauth/access_token') + def account_handler(url, _): + if url.query.find("code=somecode") > 0: + content = {'access_token': 'someaccesstoken'} + return py_json.dumps(content) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'github.com', path='/api/v3/user') + def user_handler(_, __): + content = { + 'id': 'someid', + 'login': 'someusername' + } + return py_json.dumps(content) + + @urlmatch(netloc=r'github.com', path='/api/v3/user/emails') + def email_handler(_, __): + content = [{ + 'email': 'someemail@example.com', + 'verified': True, + 'primary': True, + }] + return py_json.dumps(content) + + with HTTMock(account_handler, email_handler, user_handler): + self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github', + 'someid', 'someusername') + + def test_dex_oauth(self): + # TODO(jschorr): Add tests for invalid and expired keys. + + # Generate a public/private key pair for the OIDC transaction. + private_key = RSA.generate(2048) + jwk = RSAKey(key=private_key.publickey()).serialize() + token = jwt.encode({ + 'iss': 'https://oidcserver/', + 'aud': 'someclientid', + 'sub': 'someid', + 'exp': int(time.time()) + 60, + 'iat': int(time.time()), + 'nbf': int(time.time()), + 'email': 'someemail@example.com', + 'email_verified': True, + }, private_key.exportKey('PEM'), 'RS256') + + @urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration') + def wellknown_handler(url, _): + return py_json.dumps({ + 'authorization_endpoint': 'http://oidcserver/auth', + 'token_endpoint': 'http://oidcserver/token', + 'jwks_uri': 'http://oidcserver/keys', + }) + + @urlmatch(netloc=r'oidcserver', path='/token') + def account_handler(url, request): + if request.body.find("code=somecode") > 0: + return py_json.dumps({ + 'access_token': token, + }) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'oidcserver', path='/keys') + def keys_handler(_, __): + return py_json.dumps({ + "keys": [jwk], + }) + + with HTTMock(wellknown_handler, account_handler, keys_handler): + self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex', + 'someid', 'someemail') + + class WebEndpointTestCase(EndpointTestCase): def test_index(self): self.getResponse('web.index') diff --git a/test/testconfig.py b/test/testconfig.py index aee1b902f..02c678200 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -75,3 +75,12 @@ class TestConfig(DefaultConfig): INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem' PROMETHEUS_AGGREGATOR_URL = None + + FEATURE_GITHUB_LOGIN = True + FEATURE_GOOGLE_LOGIN = True + FEATURE_DEX_LOGIN = True + + DEX_LOGIN_CONFIG = { + 'CLIENT_ID': 'someclientid', + 'OIDC_SERVER': 'https://oidcserver/', + } diff --git a/util/config/oauth.py b/util/config/oauth.py index 6b1ef9b9b..44bf084f2 100644 --- a/util/config/oauth.py +++ b/util/config/oauth.py @@ -4,6 +4,11 @@ import logging import time from cachetools import TTLCache +from cachetools.func import lru_cache + +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_der_public_key + from jwkest.jwk import KEYS from util import slash_join @@ -64,12 +69,14 @@ class OAuthConfig(object): else: get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth) + if get_access_token.status_code / 100 != 2: + return None + json_data = get_access_token.json() if not json_data: - return '' + return None - token = json_data.get('access_token', '') - return token + return json_data.get('access_token', None) class GithubOAuthConfig(OAuthConfig): @@ -265,11 +272,15 @@ class OIDCConfig(OAuthConfig): super(OIDCConfig, self).__init__(config, key_name) self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) - self._oidc_config = {} + self._config = config self._http_client = config['HTTPCLIENT'] + @lru_cache(maxsize=1) + def _oidc_config(self): if self.config.get('OIDC_SERVER'): - self._load_via_discovery(config.get('DEBUGGING', False)) + return self._load_via_discovery(self._config.get('DEBUGGING', False)) + else: + return {} def _load_via_discovery(self, is_debugging): oidc_server = self.config['OIDC_SERVER'] @@ -283,16 +294,16 @@ class OIDCConfig(OAuthConfig): raise Exception("Could not load OIDC discovery information") try: - self._oidc_config = json.loads(discovery.text) + return json.loads(discovery.text) except ValueError: logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) raise Exception("Could not parse OIDC discovery information") def authorize_endpoint(self): - return self._oidc_config.get('authorization_endpoint', '') + '?' + return self._oidc_config().get('authorization_endpoint', '') + '?' def token_endpoint(self): - return self._oidc_config.get('token_endpoint') + return self._oidc_config().get('token_endpoint') def user_endpoint(self): return None @@ -322,9 +333,9 @@ class OIDCConfig(OAuthConfig): # a random key chose to be stored in the cache, and could be anything. return self._public_key_cache[None] - def _get_public_key(self): + def _get_public_key(self, _): """ Retrieves the public key for this handler. """ - keys_url = self._oidc_config['jwks_uri'] + keys_url = self._oidc_config()['jwks_uri'] keys = KEYS() keys.load_from_url(keys_url) @@ -334,7 +345,10 @@ class OIDCConfig(OAuthConfig): rsa_key = list(keys)[0] rsa_key.deserialize() - return rsa_key.key.exportKey('PEM') + + # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing + # issues. + return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend()) class DexOAuthConfig(OIDCConfig):