Merge pull request #2224 from coreos-inc/oauth-state

Have Quay always use an OAuth-specific CSRF token
This commit is contained in:
josephschorr 2016-12-09 15:16:01 -05:00 committed by GitHub
commit 648fed769b
12 changed files with 310 additions and 74 deletions

View file

@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
api_bp = Blueprint('api', __name__) api_bp = Blueprint('api', __name__)
api = Api() api = Api()
api.init_app(api_bp) api.init_app(api_bp)
api.decorators = [csrf_protect, api.decorators = [csrf_protect(),
crossdomain(origin='*', headers=['Authorization', 'Content-Type']), crossdomain(origin='*', headers=['Authorization', 'Content-Type']),
process_oauth, time_decorator(api_bp.name, metric_queue)] process_oauth, time_decorator(api_bp.name, metric_queue)]

View file

@ -26,6 +26,7 @@ from endpoints.api import (ApiResource, nickname, resource, validate_json_reques
from endpoints.exception import NotFound, InvalidToken from endpoints.exception import NotFound, InvalidToken
from endpoints.api.subscribe import subscribe from endpoints.api.subscribe import subscribe
from endpoints.common import common_login from endpoints.common import common_login
from endpoints.csrf import generate_csrf_token, OAUTH_CSRF_TOKEN_NAME
from endpoints.decorators import anon_allowed from endpoints.decorators import anon_allowed
from util.useremails import (send_confirmation_email, send_recovery_email, send_change_email, from util.useremails import (send_confirmation_email, send_recovery_email, send_change_email,
send_password_changed, send_org_recovery_email) send_password_changed, send_org_recovery_email)
@ -673,6 +674,15 @@ class Signout(ApiResource):
return {'success': True} return {'success': True}
@resource('/v1/externaltoken')
@internal_only
class GenerateExternalToken(ApiResource):
""" Resource for generating a token for external login. """
@nickname('generateExternalLoginToken')
def post(self):
""" Generates a CSRF token explicitly for OIDC/OAuth-associated login. """
return {'token': generate_csrf_token(OAUTH_CSRF_TOKEN_NAME)}
@resource('/v1/detachexternal/<servicename>') @resource('/v1/detachexternal/<servicename>')
@show_if(features.DIRECT_LOGIN) @show_if(features.DIRECT_LOGIN)

View file

@ -1,9 +1,10 @@
import logging import logging
import os import os
import base64 import base64
import hmac
from flask import session, request
from functools import wraps from functools import wraps
from flask import session, request
from app import app from app import app
from auth.auth_context import get_validated_oauth_token from auth.auth_context import get_validated_oauth_token
@ -12,31 +13,47 @@ from util.http import abort
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
OAUTH_CSRF_TOKEN_NAME = '_oauth_csrf_token'
_QUAY_CSRF_TOKEN_NAME = '_csrf_token'
def generate_csrf_token(): def generate_csrf_token(session_token_name=_QUAY_CSRF_TOKEN_NAME):
if '_csrf_token' not in session: """ If not present in the session, generates a new CSRF token with the given name
session['_csrf_token'] = base64.b64encode(os.urandom(48)) and places it into the session. Returns the generated token.
"""
if session_token_name not in session:
session[session_token_name] = base64.b64encode(os.urandom(48))
return session['_csrf_token'] return session[session_token_name]
def verify_csrf():
token = session.get('_csrf_token', None)
found_token = request.values.get('_csrf_token', None)
if not token or token != found_token: def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME,
msg = 'CSRF Failure. Session token was %s and request token was %s' request_token_name=_QUAY_CSRF_TOKEN_NAME):
logger.error(msg, token, found_token) """ Verifies that the CSRF token with the given name is found in the session and
that the matching token is found in the request args or values.
"""
token = str(session.get(session_token_name, ''))
found_token = str(request.values.get(request_token_name, ''))
if not token or not found_token or not hmac.compare_digest(token, found_token):
msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s'
logger.error(msg, session_token_name, token, request_token_name, found_token)
abort(403, message='CSRF token was invalid or missing.') abort(403, message='CSRF token was invalid or missing.')
def csrf_protect(func):
@wraps(func)
def wrapper(*args, **kwargs):
oauth_token = get_validated_oauth_token()
if oauth_token is None and request.method != "GET" and request.method != "HEAD":
verify_csrf()
return func(*args, **kwargs) def csrf_protect(session_token_name=_QUAY_CSRF_TOKEN_NAME,
return wrapper request_token_name=_QUAY_CSRF_TOKEN_NAME,
all_methods=False):
def inner(func):
@wraps(func)
def wrapper(*args, **kwargs):
oauth_token = get_validated_oauth_token()
if oauth_token is None:
if all_methods or (request.method != "GET" and request.method != "HEAD"):
verify_csrf(session_token_name, request_token_name)
return func(*args, **kwargs)
return wrapper
return inner
app.jinja_env.globals['csrf_token'] = generate_csrf_token app.jinja_env.globals['csrf_token'] = generate_csrf_token

View file

@ -12,6 +12,7 @@ from auth.process import require_session_login
from data import model from data import model
from endpoints.common import common_login, route_show_if from endpoints.common import common_login, route_show_if
from endpoints.web import index from endpoints.web import index
from endpoints.csrf import csrf_protect, OAUTH_CSRF_TOKEN_NAME
from util.security.jwtutil import decode, InvalidTokenError from util.security.jwtutil import decode, InvalidTokenError
from util.validation import generate_valid_usernames from util.validation import generate_valid_usernames
@ -19,6 +20,7 @@ logger = logging.getLogger(__name__)
client = app.config['HTTPCLIENT'] client = app.config['HTTPCLIENT']
oauthlogin = Blueprint('oauthlogin', __name__) oauthlogin = Blueprint('oauthlogin', __name__)
oauthlogin_csrf_protect = csrf_protect(OAUTH_CSRF_TOKEN_NAME, 'state', all_methods=True)
def render_ologin_error(service_name, error_message=None, register_redirect=False): def render_ologin_error(service_name, error_message=None, register_redirect=False):
user_creation = bool(features.USER_CREATION and features.DIRECT_LOGIN) user_creation = bool(features.USER_CREATION and features.DIRECT_LOGIN)
@ -30,7 +32,11 @@ def render_ologin_error(service_name, error_message=None, register_redirect=Fals
'user_creation': user_creation, 'user_creation': user_creation,
'register_redirect': register_redirect, 'register_redirect': register_redirect,
} }
return index('', error_info=error_info)
resp = index('', error_info=error_info)
resp.status_code = 400
return resp
def get_user(service, token): def get_user(service, token):
token_param = { token_param = {
@ -44,7 +50,7 @@ def get_user(service, token):
return got_user.json() return got_user.json()
def conduct_oauth_login(service, user_id, username, email, metadata={}): def conduct_oauth_login(service, user_id, username, email, metadata=None):
service_name = service.service_name() service_name = service.service_name()
to_login = model.user.verify_federated_login(service_name.lower(), user_id) to_login = model.user.verify_federated_login(service_name.lower(), user_id)
if not to_login: if not to_login:
@ -66,17 +72,12 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}):
prompts = model.user.get_default_user_prompts(features) prompts = model.user.get_default_user_prompts(features)
to_login = model.user.create_federated_user(new_username, email, service_name.lower(), to_login = model.user.create_federated_user(new_username, email, service_name.lower(),
user_id, set_password_notification=True, user_id, set_password_notification=True,
metadata=metadata, metadata=metadata or {},
prompts=prompts) prompts=prompts)
# Success, tell analytics # Success, tell analytics
analytics.track(to_login.username, 'register', {'service': service_name.lower()}) analytics.track(to_login.username, 'register', {'service': service_name.lower()})
state = request.args.get('state', None)
if state:
logger.debug('Aliasing with state: %s', state)
analytics.alias(to_login.username, state)
except model.InvalidEmailAddressException: except model.InvalidEmailAddressException:
message = "The e-mail address %s is already associated " % (email, ) message = "The e-mail address %s is already associated " % (email, )
message = message + "with an existing %s account." % (app.config['REGISTRY_TITLE_SHORT'], ) message = message + "with an existing %s account." % (app.config['REGISTRY_TITLE_SHORT'], )
@ -96,6 +97,7 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}):
return render_ologin_error(service_name) return render_ologin_error(service_name)
def get_email_username(user_data): def get_email_username(user_data):
username = user_data['email'] username = user_data['email']
at = username.find('@') at = username.find('@')
@ -107,6 +109,7 @@ def get_email_username(user_data):
@oauthlogin.route('/google/callback', methods=['GET']) @oauthlogin.route('/google/callback', methods=['GET'])
@route_show_if(features.GOOGLE_LOGIN) @route_show_if(features.GOOGLE_LOGIN)
@oauthlogin_csrf_protect
def google_oauth_callback(): def google_oauth_callback():
error = request.args.get('error', None) error = request.args.get('error', None)
if error: if error:
@ -114,6 +117,9 @@ def google_oauth_callback():
code = request.args.get('code') code = request.args.get('code')
token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True) token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True)
if token is None:
return render_ologin_error('Google')
user_data = get_user(google_login, token) user_data = get_user(google_login, token)
if not user_data or not user_data.get('id', None) or not user_data.get('email', None): if not user_data or not user_data.get('id', None) or not user_data.get('email', None):
return render_ologin_error('Google') return render_ologin_error('Google')
@ -136,6 +142,7 @@ def google_oauth_callback():
@oauthlogin.route('/github/callback', methods=['GET']) @oauthlogin.route('/github/callback', methods=['GET'])
@route_show_if(features.GITHUB_LOGIN) @route_show_if(features.GITHUB_LOGIN)
@oauthlogin_csrf_protect
def github_oauth_callback(): def github_oauth_callback():
error = request.args.get('error', None) error = request.args.get('error', None)
if error: if error:
@ -144,10 +151,12 @@ def github_oauth_callback():
# Exchange the OAuth code. # Exchange the OAuth code.
code = request.args.get('code') code = request.args.get('code')
token = github_login.exchange_code_for_token(app.config, client, code) token = github_login.exchange_code_for_token(app.config, client, code)
if token is None:
return render_ologin_error('GitHub')
# Retrieve the user's information. # Retrieve the user's information.
user_data = get_user(github_login, token) user_data = get_user(github_login, token)
if not user_data or not 'login' in user_data: if not user_data or 'login' not in user_data:
return render_ologin_error('GitHub') return render_ologin_error('GitHub')
username = user_data['login'] username = user_data['login']
@ -167,7 +176,8 @@ def github_oauth_callback():
headers={'Accept': 'application/vnd.github.moondragon+json'}) headers={'Accept': 'application/vnd.github.moondragon+json'})
organizations = set([org.get('login').lower() for org in get_orgs.json()]) organizations = set([org.get('login').lower() for org in get_orgs.json()])
if not (organizations & set(github_login.allowed_organizations())): matching_organizations = organizations & set(github_login.allowed_organizations())
if not matching_organizations:
err = """You are not a member of an allowed GitHub organization. err = """You are not a member of an allowed GitHub organization.
Please contact your system administrator if you believe this is in error.""" Please contact your system administrator if you believe this is in error."""
return render_ologin_error('GitHub', err) return render_ologin_error('GitHub', err)
@ -175,6 +185,8 @@ def github_oauth_callback():
# Find the e-mail address for the user: we will accept any email, but we prefer the primary # Find the e-mail address for the user: we will accept any email, but we prefer the primary
get_email = client.get(github_login.email_endpoint(), params=token_param, get_email = client.get(github_login.email_endpoint(), params=token_param,
headers=v3_media_type) headers=v3_media_type)
if get_email.status_code / 100 != 2:
return render_ologin_error('GitHub')
found_email = None found_email = None
for user_email in get_email.json(): for user_email in get_email.json():
@ -199,10 +211,13 @@ def github_oauth_callback():
@oauthlogin.route('/google/callback/attach', methods=['GET']) @oauthlogin.route('/google/callback/attach', methods=['GET'])
@route_show_if(features.GOOGLE_LOGIN) @route_show_if(features.GOOGLE_LOGIN)
@require_session_login @require_session_login
@oauthlogin_csrf_protect
def google_oauth_attach(): def google_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = google_login.exchange_code_for_token(app.config, client, code, token = google_login.exchange_code_for_token(app.config, client, code,
redirect_suffix='/attach', form_encode=True) redirect_suffix='/attach', form_encode=True)
if token is None:
return render_ologin_error('Google')
user_data = get_user(google_login, token) user_data = get_user(google_login, token)
if not user_data or not user_data.get('id', None): if not user_data or not user_data.get('id', None):
@ -236,9 +251,13 @@ def google_oauth_attach():
@oauthlogin.route('/github/callback/attach', methods=['GET']) @oauthlogin.route('/github/callback/attach', methods=['GET'])
@route_show_if(features.GITHUB_LOGIN) @route_show_if(features.GITHUB_LOGIN)
@require_session_login @require_session_login
@oauthlogin_csrf_protect
def github_oauth_attach(): def github_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = github_login.exchange_code_for_token(app.config, client, code) token = github_login.exchange_code_for_token(app.config, client, code)
if token is None:
return render_ologin_error('GitHub')
user_data = get_user(github_login, token) user_data = get_user(github_login, token)
if not user_data: if not user_data:
return render_ologin_error('GitHub') return render_ologin_error('GitHub')
@ -276,6 +295,7 @@ def decode_user_jwt(token, oidc_provider):
@oauthlogin.route('/dex/callback', methods=['GET', 'POST']) @oauthlogin.route('/dex/callback', methods=['GET', 'POST'])
@route_show_if(features.DEX_LOGIN) @route_show_if(features.DEX_LOGIN)
@oauthlogin_csrf_protect
def dex_oauth_callback(): def dex_oauth_callback():
error = request.values.get('error', None) error = request.values.get('error', None)
if error: if error:
@ -287,6 +307,8 @@ def dex_oauth_callback():
token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True, token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True,
form_encode=True) form_encode=True)
if token is None:
return render_ologin_error(dex_login.public_title)
try: try:
payload = decode_user_jwt(token, dex_login) payload = decode_user_jwt(token, dex_login)
@ -318,11 +340,12 @@ def dex_oauth_callback():
@oauthlogin.route('/dex/callback/attach', methods=['GET', 'POST']) @oauthlogin.route('/dex/callback/attach', methods=['GET', 'POST'])
@route_show_if(features.DEX_LOGIN) @route_show_if(features.DEX_LOGIN)
@require_session_login @require_session_login
@oauthlogin_csrf_protect
def dex_oauth_attach(): def dex_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach', token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach',
client_auth=True, form_encode=True) client_auth=True, form_encode=True)
if not token: if token is None:
return render_ologin_error(dex_login.public_title) return render_ologin_error(dex_login.public_title)
try: try:
@ -346,4 +369,3 @@ def dex_oauth_attach():
return render_ologin_error(dex_login.public_title, err) return render_ologin_error(dex_login.public_title, err)
return redirect(url_for('web.user_view', path=user_obj.username, tab='external')) return redirect(url_for('web.user_view', path=user_obj.username, tab='external'))

View file

@ -494,7 +494,7 @@ def oauth_local_handler():
@web.route('/oauth/denyapp', methods=['POST']) @web.route('/oauth/denyapp', methods=['POST'])
@csrf_protect @csrf_protect()
def deny_application(): def deny_application():
if not current_user.is_authenticated: if not current_user.is_authenticated:
abort(401) abort(401)

View file

@ -215,7 +215,8 @@ quayApp.config(['$routeProvider', '$locationProvider', 'pages', function($routeP
// 404/403 // 404/403
.route('/:catchall', 'error-view') .route('/:catchall', 'error-view')
.route('/:catch/:all', 'error-view') .route('/:catch/:all', 'error-view')
.route('/:catch/:all/:things', 'error-view'); .route('/:catch/:all/:things', 'error-view')
.route('/:catch/:all/:things/:here', 'error-view');
}]); }]);
// Configure compile provider to add additional URL prefixes to the sanitization list. We use // Configure compile provider to add additional URL prefixes to the sanitization list. We use

View file

@ -21,19 +21,21 @@ angular.module('quay').directive('externalLoginButton', function () {
$scope.startSignin = function() { $scope.startSignin = function() {
$scope.signInStarted({'service': $scope.provider}); $scope.signInStarted({'service': $scope.provider});
ApiService.generateExternalLoginToken().then(function(data) {
var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login');
url = url + '&state=' + encodeURIComponent(data['token']);
var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login'); // Save the redirect URL in a cookie so that we can redirect back after the service returns to us.
var redirectURL = $scope.redirectUrl || window.location.toString();
CookieService.putPermanent('quay.redirectAfterLoad', redirectURL);
// Save the redirect URL in a cookie so that we can redirect back after the service returns to us. // Needed to ensure that UI work done by the started callback is finished before the location
var redirectURL = $scope.redirectUrl || window.location.toString(); // changes.
CookieService.putPermanent('quay.redirectAfterLoad', redirectURL); $scope.signingIn = true;
$timeout(function() {
// Needed to ensure that UI work done by the started callback is finished before the location document.location = url;
// changes. }, 250);
$scope.signingIn = true; }, ApiService.errorDisplay('Could not perform sign in'));
$timeout(function() {
document.location = url;
}, 250);
}; };
} }
}; };

View file

@ -9,14 +9,6 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features'
var serviceInfo = externalLoginService.getProvider(service); var serviceInfo = externalLoginService.getProvider(service);
if (!serviceInfo) { return ''; } if (!serviceInfo) { return ''; }
var stateClause = '';
if (Config.MIXPANEL_KEY && window.mixpanel) {
if (mixpanel.get_distinct_id !== undefined) {
stateClause = "&state=" + encodeURIComponent(mixpanel.get_distinct_id());
}
}
var loginUrl = KeyService.getConfiguration(serviceInfo.key, 'AUTHORIZE_ENDPOINT'); var loginUrl = KeyService.getConfiguration(serviceInfo.key, 'AUTHORIZE_ENDPOINT');
var clientId = KeyService.getConfiguration(serviceInfo.key, 'CLIENT_ID'); var clientId = KeyService.getConfiguration(serviceInfo.key, 'CLIENT_ID');
@ -28,8 +20,7 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features'
} }
var url = loginUrl + 'client_id=' + clientId + '&scope=' + scope + '&redirect_uri=' + var url = loginUrl + 'client_id=' + clientId + '&scope=' + scope + '&redirect_uri=' +
redirectUri + stateClause; redirectUri;
return url; return url;
}; };

View file

@ -124,9 +124,11 @@ class ApiTestCase(unittest.TestCase):
query[CSRF_TOKEN_KEY] = CSRF_TOKEN query[CSRF_TOKEN_KEY] = CSRF_TOKEN
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:])) return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
def url_for(self, resource_name, params={}): def url_for(self, resource_name, params=None, skip_csrf=False):
params = params or {}
url = api.url_for(resource_name, **params) url = api.url_for(resource_name, **params)
url = ApiTestCase._add_csrf(url) if not skip_csrf:
url = ApiTestCase._add_csrf(url)
return url return url
def setUp(self): def setUp(self):
@ -211,8 +213,8 @@ class ApiTestCase(unittest.TestCase):
return parsed return parsed
def putJsonResponse(self, resource_name, params={}, data={}, def putJsonResponse(self, resource_name, params={}, data={},
expected_code=200): expected_code=200, skip_csrf=False):
rv = self.app.put(self.url_for(resource_name, params), rv = self.app.put(self.url_for(resource_name, params, skip_csrf),
data=py_json.dumps(data), data=py_json.dumps(data),
headers={"Content-Type": "application/json"}) headers={"Content-Type": "application/json"})
@ -246,15 +248,35 @@ class TestCSRFFailure(ApiTestCase):
self.login(READ_ACCESS_USER) self.login(READ_ACCESS_USER)
# Make sure a simple post call succeeds. # Make sure a simple post call succeeds.
self.putJsonResponse(User, self.putJsonResponse(User, data=dict(password='newpasswordiscool'))
data=dict(password='newpasswordiscool'))
# Change the session's CSRF token. # Change the session's CSRF token.
self.setCsrfToken('someinvalidtoken') self.setCsrfToken('someinvalidtoken')
# Verify that the call now fails. # Verify that the call now fails.
self.putJsonResponse(User, self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
data=dict(password='newpasswordiscool'),
def test_csrf_failure_empty_token(self):
self.login(READ_ACCESS_USER)
# Change the session's CSRF token to be empty.
self.setCsrfToken('')
# Verify that the call now fails.
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
def test_csrf_failure_missing_token(self):
self.login(READ_ACCESS_USER)
# Make sure a simple post call without a token at all fails.
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
expected_code=403)
# Change the session's CSRF token to be empty.
self.setCsrfToken('')
# Verify that the call still fails.
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
expected_code=403) expected_code=403)

View file

@ -8,6 +8,7 @@ import base64
from urllib import urlencode from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs from urlparse import urlparse, urlunparse, parse_qs
from datetime import datetime, timedelta from datetime import datetime, timedelta
from httmock import urlmatch, HTTMock
import jwt import jwt
@ -22,13 +23,25 @@ from endpoints import keyserver
from endpoints.api import api, api_bp from endpoints.api import api, api_bp
from endpoints.api.user import Signin from endpoints.api.user import Signin
from endpoints.keyserver import jwk_with_kid from endpoints.keyserver import jwk_with_kid
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
from endpoints.web import web as web_bp from endpoints.web import web as web_bp
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
from initdb import setup_database_for_testing, finished_database_for_testing from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import assert_action_logged from test.helpers import assert_action_logged
try:
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth')
except ValueError:
# This blueprint was already registered
pass
try: try:
app.register_blueprint(web_bp, url_prefix='') app.register_blueprint(web_bp, url_prefix='')
except ValueError:
# This blueprint was already registered
pass
try:
app.register_blueprint(keyserver.key_server, url_prefix='') app.register_blueprint(keyserver.key_server, url_prefix='')
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
@ -69,6 +82,7 @@ class EndpointTestCase(unittest.TestCase):
def setCsrfToken(self, token): def setCsrfToken(self, token):
with self.app.session_transaction() as sess: with self.app.session_transaction() as sess:
sess[CSRF_TOKEN_KEY] = token sess[CSRF_TOKEN_KEY] = token
sess[OAUTH_CSRF_TOKEN_NAME] = 'someoauthtoken'
def getResponse(self, resource_name, expected_code=200, **kwargs): def getResponse(self, resource_name, expected_code=200, **kwargs):
rv = self.app.get(url_for(resource_name, **kwargs)) rv = self.app.get(url_for(resource_name, **kwargs))
@ -108,6 +122,140 @@ class EndpointTestCase(unittest.TestCase):
self.assertEquals(rv.status_code, 200) self.assertEquals(rv.status_code, 200)
class OAuthLoginTestCase(EndpointTestCase):
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
new_username):
# Test callback.
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
# Delete the created user.
model.user.delete_user(created, [])
# Test attach.
self.login('devtable', 'password')
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
# No CSRF.
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
# Invalid CSRF.
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
# Valid CSRF, invalid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='invalidcode', expected_code=400)
# Valid CSRF, valid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='somecode', expected_code=302)
# Ensure the user was added/modified.
found_user = model.user.get_user(username)
self.assertIsNotNone(found_user)
federated_login = model.user.lookup_federated_login(found_user, service_name)
self.assertIsNotNone(federated_login)
self.assertEquals(federated_login.service_ident, service_ident)
return found_user
def test_google_oauth(self):
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
def account_handler(_, request):
if request.body.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
def user_handler(_, __):
content = {
'id': 'someid',
'email': 'someemail@example.com',
'verified_email': True,
}
return py_json.dumps(content)
with HTTMock(account_handler, user_handler):
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
'someid', 'someemail')
def test_github_oauth(self):
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
def account_handler(url, _):
if url.query.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'github.com', path='/api/v3/user')
def user_handler(_, __):
content = {
'id': 'someid',
'login': 'someusername'
}
return py_json.dumps(content)
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
def email_handler(_, __):
content = [{
'email': 'someemail@example.com',
'verified': True,
'primary': True,
}]
return py_json.dumps(content)
with HTTMock(account_handler, email_handler, user_handler):
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
'someid', 'someusername')
def test_dex_oauth(self):
# TODO(jschorr): Add tests for invalid and expired keys.
# Generate a public/private key pair for the OIDC transaction.
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
token = jwt.encode({
'iss': 'https://oidcserver/',
'aud': 'someclientid',
'sub': 'someid',
'exp': int(time.time()) + 60,
'iat': int(time.time()),
'nbf': int(time.time()),
'email': 'someemail@example.com',
'email_verified': True,
}, private_key.exportKey('PEM'), 'RS256')
@urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration')
def wellknown_handler(url, _):
return py_json.dumps({
'authorization_endpoint': 'http://oidcserver/auth',
'token_endpoint': 'http://oidcserver/token',
'jwks_uri': 'http://oidcserver/keys',
})
@urlmatch(netloc=r'oidcserver', path='/token')
def account_handler(url, request):
if request.body.find("code=somecode") > 0:
return py_json.dumps({
'access_token': token,
})
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'oidcserver', path='/keys')
def keys_handler(_, __):
return py_json.dumps({
"keys": [jwk],
})
with HTTMock(wellknown_handler, account_handler, keys_handler):
self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex',
'someid', 'someemail')
class WebEndpointTestCase(EndpointTestCase): class WebEndpointTestCase(EndpointTestCase):
def test_index(self): def test_index(self):
self.getResponse('web.index') self.getResponse('web.index')

View file

@ -75,3 +75,12 @@ class TestConfig(DefaultConfig):
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem' INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
PROMETHEUS_AGGREGATOR_URL = None PROMETHEUS_AGGREGATOR_URL = None
FEATURE_GITHUB_LOGIN = True
FEATURE_GOOGLE_LOGIN = True
FEATURE_DEX_LOGIN = True
DEX_LOGIN_CONFIG = {
'CLIENT_ID': 'someclientid',
'OIDC_SERVER': 'https://oidcserver/',
}

View file

@ -4,6 +4,11 @@ import logging
import time import time
from cachetools import TTLCache from cachetools import TTLCache
from cachetools.func import lru_cache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS from jwkest.jwk import KEYS
from util import slash_join from util import slash_join
@ -64,12 +69,14 @@ class OAuthConfig(object):
else: else:
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth) get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
if get_access_token.status_code / 100 != 2:
return None
json_data = get_access_token.json() json_data = get_access_token.json()
if not json_data: if not json_data:
return '' return None
token = json_data.get('access_token', '') return json_data.get('access_token', None)
return token
class GithubOAuthConfig(OAuthConfig): class GithubOAuthConfig(OAuthConfig):
@ -265,11 +272,15 @@ class OIDCConfig(OAuthConfig):
super(OIDCConfig, self).__init__(config, key_name) super(OIDCConfig, self).__init__(config, key_name)
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
self._oidc_config = {} self._config = config
self._http_client = config['HTTPCLIENT'] self._http_client = config['HTTPCLIENT']
@lru_cache(maxsize=1)
def _oidc_config(self):
if self.config.get('OIDC_SERVER'): if self.config.get('OIDC_SERVER'):
self._load_via_discovery(config.get('DEBUGGING', False)) return self._load_via_discovery(self._config.get('DEBUGGING', False))
else:
return {}
def _load_via_discovery(self, is_debugging): def _load_via_discovery(self, is_debugging):
oidc_server = self.config['OIDC_SERVER'] oidc_server = self.config['OIDC_SERVER']
@ -283,16 +294,16 @@ class OIDCConfig(OAuthConfig):
raise Exception("Could not load OIDC discovery information") raise Exception("Could not load OIDC discovery information")
try: try:
self._oidc_config = json.loads(discovery.text) return json.loads(discovery.text)
except ValueError: except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise Exception("Could not parse OIDC discovery information") raise Exception("Could not parse OIDC discovery information")
def authorize_endpoint(self): def authorize_endpoint(self):
return self._oidc_config.get('authorization_endpoint', '') + '?' return self._oidc_config().get('authorization_endpoint', '') + '?'
def token_endpoint(self): def token_endpoint(self):
return self._oidc_config.get('token_endpoint') return self._oidc_config().get('token_endpoint')
def user_endpoint(self): def user_endpoint(self):
return None return None
@ -322,9 +333,9 @@ class OIDCConfig(OAuthConfig):
# a random key chose to be stored in the cache, and could be anything. # a random key chose to be stored in the cache, and could be anything.
return self._public_key_cache[None] return self._public_key_cache[None]
def _get_public_key(self): def _get_public_key(self, _):
""" Retrieves the public key for this handler. """ """ Retrieves the public key for this handler. """
keys_url = self._oidc_config['jwks_uri'] keys_url = self._oidc_config()['jwks_uri']
keys = KEYS() keys = KEYS()
keys.load_from_url(keys_url) keys.load_from_url(keys_url)
@ -334,7 +345,10 @@ class OIDCConfig(OAuthConfig):
rsa_key = list(keys)[0] rsa_key = list(keys)[0]
rsa_key.deserialize() rsa_key.deserialize()
return rsa_key.key.exportKey('PEM')
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
# issues.
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
class DexOAuthConfig(OIDCConfig): class DexOAuthConfig(OIDCConfig):