Merge pull request #2224 from coreos-inc/oauth-state
Have Quay always use an OAuth-specific CSRF token
This commit is contained in:
commit
648fed769b
12 changed files with 310 additions and 74 deletions
|
@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
|
||||||
api_bp = Blueprint('api', __name__)
|
api_bp = Blueprint('api', __name__)
|
||||||
api = Api()
|
api = Api()
|
||||||
api.init_app(api_bp)
|
api.init_app(api_bp)
|
||||||
api.decorators = [csrf_protect,
|
api.decorators = [csrf_protect(),
|
||||||
crossdomain(origin='*', headers=['Authorization', 'Content-Type']),
|
crossdomain(origin='*', headers=['Authorization', 'Content-Type']),
|
||||||
process_oauth, time_decorator(api_bp.name, metric_queue)]
|
process_oauth, time_decorator(api_bp.name, metric_queue)]
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ from endpoints.api import (ApiResource, nickname, resource, validate_json_reques
|
||||||
from endpoints.exception import NotFound, InvalidToken
|
from endpoints.exception import NotFound, InvalidToken
|
||||||
from endpoints.api.subscribe import subscribe
|
from endpoints.api.subscribe import subscribe
|
||||||
from endpoints.common import common_login
|
from endpoints.common import common_login
|
||||||
|
from endpoints.csrf import generate_csrf_token, OAUTH_CSRF_TOKEN_NAME
|
||||||
from endpoints.decorators import anon_allowed
|
from endpoints.decorators import anon_allowed
|
||||||
from util.useremails import (send_confirmation_email, send_recovery_email, send_change_email,
|
from util.useremails import (send_confirmation_email, send_recovery_email, send_change_email,
|
||||||
send_password_changed, send_org_recovery_email)
|
send_password_changed, send_org_recovery_email)
|
||||||
|
@ -673,6 +674,15 @@ class Signout(ApiResource):
|
||||||
return {'success': True}
|
return {'success': True}
|
||||||
|
|
||||||
|
|
||||||
|
@resource('/v1/externaltoken')
|
||||||
|
@internal_only
|
||||||
|
class GenerateExternalToken(ApiResource):
|
||||||
|
""" Resource for generating a token for external login. """
|
||||||
|
@nickname('generateExternalLoginToken')
|
||||||
|
def post(self):
|
||||||
|
""" Generates a CSRF token explicitly for OIDC/OAuth-associated login. """
|
||||||
|
return {'token': generate_csrf_token(OAUTH_CSRF_TOKEN_NAME)}
|
||||||
|
|
||||||
|
|
||||||
@resource('/v1/detachexternal/<servicename>')
|
@resource('/v1/detachexternal/<servicename>')
|
||||||
@show_if(features.DIRECT_LOGIN)
|
@show_if(features.DIRECT_LOGIN)
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import base64
|
import base64
|
||||||
|
import hmac
|
||||||
|
|
||||||
from flask import session, request
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from flask import session, request
|
||||||
|
|
||||||
from app import app
|
from app import app
|
||||||
from auth.auth_context import get_validated_oauth_token
|
from auth.auth_context import get_validated_oauth_token
|
||||||
|
@ -12,31 +13,47 @@ from util.http import abort
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
OAUTH_CSRF_TOKEN_NAME = '_oauth_csrf_token'
|
||||||
|
_QUAY_CSRF_TOKEN_NAME = '_csrf_token'
|
||||||
|
|
||||||
def generate_csrf_token():
|
def generate_csrf_token(session_token_name=_QUAY_CSRF_TOKEN_NAME):
|
||||||
if '_csrf_token' not in session:
|
""" If not present in the session, generates a new CSRF token with the given name
|
||||||
session['_csrf_token'] = base64.b64encode(os.urandom(48))
|
and places it into the session. Returns the generated token.
|
||||||
|
"""
|
||||||
|
if session_token_name not in session:
|
||||||
|
session[session_token_name] = base64.b64encode(os.urandom(48))
|
||||||
|
|
||||||
return session['_csrf_token']
|
return session[session_token_name]
|
||||||
|
|
||||||
def verify_csrf():
|
|
||||||
token = session.get('_csrf_token', None)
|
|
||||||
found_token = request.values.get('_csrf_token', None)
|
|
||||||
|
|
||||||
if not token or token != found_token:
|
def verify_csrf(session_token_name=_QUAY_CSRF_TOKEN_NAME,
|
||||||
msg = 'CSRF Failure. Session token was %s and request token was %s'
|
request_token_name=_QUAY_CSRF_TOKEN_NAME):
|
||||||
logger.error(msg, token, found_token)
|
""" Verifies that the CSRF token with the given name is found in the session and
|
||||||
|
that the matching token is found in the request args or values.
|
||||||
|
"""
|
||||||
|
token = str(session.get(session_token_name, ''))
|
||||||
|
found_token = str(request.values.get(request_token_name, ''))
|
||||||
|
|
||||||
|
if not token or not found_token or not hmac.compare_digest(token, found_token):
|
||||||
|
msg = 'CSRF Failure. Session token (%s) was %s and request token (%s) was %s'
|
||||||
|
logger.error(msg, session_token_name, token, request_token_name, found_token)
|
||||||
abort(403, message='CSRF token was invalid or missing.')
|
abort(403, message='CSRF token was invalid or missing.')
|
||||||
|
|
||||||
def csrf_protect(func):
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
oauth_token = get_validated_oauth_token()
|
|
||||||
if oauth_token is None and request.method != "GET" and request.method != "HEAD":
|
|
||||||
verify_csrf()
|
|
||||||
|
|
||||||
return func(*args, **kwargs)
|
def csrf_protect(session_token_name=_QUAY_CSRF_TOKEN_NAME,
|
||||||
return wrapper
|
request_token_name=_QUAY_CSRF_TOKEN_NAME,
|
||||||
|
all_methods=False):
|
||||||
|
def inner(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
oauth_token = get_validated_oauth_token()
|
||||||
|
if oauth_token is None:
|
||||||
|
if all_methods or (request.method != "GET" and request.method != "HEAD"):
|
||||||
|
verify_csrf(session_token_name, request_token_name)
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
app.jinja_env.globals['csrf_token'] = generate_csrf_token
|
app.jinja_env.globals['csrf_token'] = generate_csrf_token
|
||||||
|
|
|
@ -12,6 +12,7 @@ from auth.process import require_session_login
|
||||||
from data import model
|
from data import model
|
||||||
from endpoints.common import common_login, route_show_if
|
from endpoints.common import common_login, route_show_if
|
||||||
from endpoints.web import index
|
from endpoints.web import index
|
||||||
|
from endpoints.csrf import csrf_protect, OAUTH_CSRF_TOKEN_NAME
|
||||||
from util.security.jwtutil import decode, InvalidTokenError
|
from util.security.jwtutil import decode, InvalidTokenError
|
||||||
from util.validation import generate_valid_usernames
|
from util.validation import generate_valid_usernames
|
||||||
|
|
||||||
|
@ -19,6 +20,7 @@ logger = logging.getLogger(__name__)
|
||||||
client = app.config['HTTPCLIENT']
|
client = app.config['HTTPCLIENT']
|
||||||
oauthlogin = Blueprint('oauthlogin', __name__)
|
oauthlogin = Blueprint('oauthlogin', __name__)
|
||||||
|
|
||||||
|
oauthlogin_csrf_protect = csrf_protect(OAUTH_CSRF_TOKEN_NAME, 'state', all_methods=True)
|
||||||
|
|
||||||
def render_ologin_error(service_name, error_message=None, register_redirect=False):
|
def render_ologin_error(service_name, error_message=None, register_redirect=False):
|
||||||
user_creation = bool(features.USER_CREATION and features.DIRECT_LOGIN)
|
user_creation = bool(features.USER_CREATION and features.DIRECT_LOGIN)
|
||||||
|
@ -30,7 +32,11 @@ def render_ologin_error(service_name, error_message=None, register_redirect=Fals
|
||||||
'user_creation': user_creation,
|
'user_creation': user_creation,
|
||||||
'register_redirect': register_redirect,
|
'register_redirect': register_redirect,
|
||||||
}
|
}
|
||||||
return index('', error_info=error_info)
|
|
||||||
|
resp = index('', error_info=error_info)
|
||||||
|
resp.status_code = 400
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
def get_user(service, token):
|
def get_user(service, token):
|
||||||
token_param = {
|
token_param = {
|
||||||
|
@ -44,7 +50,7 @@ def get_user(service, token):
|
||||||
return got_user.json()
|
return got_user.json()
|
||||||
|
|
||||||
|
|
||||||
def conduct_oauth_login(service, user_id, username, email, metadata={}):
|
def conduct_oauth_login(service, user_id, username, email, metadata=None):
|
||||||
service_name = service.service_name()
|
service_name = service.service_name()
|
||||||
to_login = model.user.verify_federated_login(service_name.lower(), user_id)
|
to_login = model.user.verify_federated_login(service_name.lower(), user_id)
|
||||||
if not to_login:
|
if not to_login:
|
||||||
|
@ -66,17 +72,12 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}):
|
||||||
prompts = model.user.get_default_user_prompts(features)
|
prompts = model.user.get_default_user_prompts(features)
|
||||||
to_login = model.user.create_federated_user(new_username, email, service_name.lower(),
|
to_login = model.user.create_federated_user(new_username, email, service_name.lower(),
|
||||||
user_id, set_password_notification=True,
|
user_id, set_password_notification=True,
|
||||||
metadata=metadata,
|
metadata=metadata or {},
|
||||||
prompts=prompts)
|
prompts=prompts)
|
||||||
|
|
||||||
# Success, tell analytics
|
# Success, tell analytics
|
||||||
analytics.track(to_login.username, 'register', {'service': service_name.lower()})
|
analytics.track(to_login.username, 'register', {'service': service_name.lower()})
|
||||||
|
|
||||||
state = request.args.get('state', None)
|
|
||||||
if state:
|
|
||||||
logger.debug('Aliasing with state: %s', state)
|
|
||||||
analytics.alias(to_login.username, state)
|
|
||||||
|
|
||||||
except model.InvalidEmailAddressException:
|
except model.InvalidEmailAddressException:
|
||||||
message = "The e-mail address %s is already associated " % (email, )
|
message = "The e-mail address %s is already associated " % (email, )
|
||||||
message = message + "with an existing %s account." % (app.config['REGISTRY_TITLE_SHORT'], )
|
message = message + "with an existing %s account." % (app.config['REGISTRY_TITLE_SHORT'], )
|
||||||
|
@ -96,6 +97,7 @@ def conduct_oauth_login(service, user_id, username, email, metadata={}):
|
||||||
|
|
||||||
return render_ologin_error(service_name)
|
return render_ologin_error(service_name)
|
||||||
|
|
||||||
|
|
||||||
def get_email_username(user_data):
|
def get_email_username(user_data):
|
||||||
username = user_data['email']
|
username = user_data['email']
|
||||||
at = username.find('@')
|
at = username.find('@')
|
||||||
|
@ -107,6 +109,7 @@ def get_email_username(user_data):
|
||||||
|
|
||||||
@oauthlogin.route('/google/callback', methods=['GET'])
|
@oauthlogin.route('/google/callback', methods=['GET'])
|
||||||
@route_show_if(features.GOOGLE_LOGIN)
|
@route_show_if(features.GOOGLE_LOGIN)
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def google_oauth_callback():
|
def google_oauth_callback():
|
||||||
error = request.args.get('error', None)
|
error = request.args.get('error', None)
|
||||||
if error:
|
if error:
|
||||||
|
@ -114,6 +117,9 @@ def google_oauth_callback():
|
||||||
|
|
||||||
code = request.args.get('code')
|
code = request.args.get('code')
|
||||||
token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True)
|
token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True)
|
||||||
|
if token is None:
|
||||||
|
return render_ologin_error('Google')
|
||||||
|
|
||||||
user_data = get_user(google_login, token)
|
user_data = get_user(google_login, token)
|
||||||
if not user_data or not user_data.get('id', None) or not user_data.get('email', None):
|
if not user_data or not user_data.get('id', None) or not user_data.get('email', None):
|
||||||
return render_ologin_error('Google')
|
return render_ologin_error('Google')
|
||||||
|
@ -136,6 +142,7 @@ def google_oauth_callback():
|
||||||
|
|
||||||
@oauthlogin.route('/github/callback', methods=['GET'])
|
@oauthlogin.route('/github/callback', methods=['GET'])
|
||||||
@route_show_if(features.GITHUB_LOGIN)
|
@route_show_if(features.GITHUB_LOGIN)
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def github_oauth_callback():
|
def github_oauth_callback():
|
||||||
error = request.args.get('error', None)
|
error = request.args.get('error', None)
|
||||||
if error:
|
if error:
|
||||||
|
@ -144,10 +151,12 @@ def github_oauth_callback():
|
||||||
# Exchange the OAuth code.
|
# Exchange the OAuth code.
|
||||||
code = request.args.get('code')
|
code = request.args.get('code')
|
||||||
token = github_login.exchange_code_for_token(app.config, client, code)
|
token = github_login.exchange_code_for_token(app.config, client, code)
|
||||||
|
if token is None:
|
||||||
|
return render_ologin_error('GitHub')
|
||||||
|
|
||||||
# Retrieve the user's information.
|
# Retrieve the user's information.
|
||||||
user_data = get_user(github_login, token)
|
user_data = get_user(github_login, token)
|
||||||
if not user_data or not 'login' in user_data:
|
if not user_data or 'login' not in user_data:
|
||||||
return render_ologin_error('GitHub')
|
return render_ologin_error('GitHub')
|
||||||
|
|
||||||
username = user_data['login']
|
username = user_data['login']
|
||||||
|
@ -167,7 +176,8 @@ def github_oauth_callback():
|
||||||
headers={'Accept': 'application/vnd.github.moondragon+json'})
|
headers={'Accept': 'application/vnd.github.moondragon+json'})
|
||||||
|
|
||||||
organizations = set([org.get('login').lower() for org in get_orgs.json()])
|
organizations = set([org.get('login').lower() for org in get_orgs.json()])
|
||||||
if not (organizations & set(github_login.allowed_organizations())):
|
matching_organizations = organizations & set(github_login.allowed_organizations())
|
||||||
|
if not matching_organizations:
|
||||||
err = """You are not a member of an allowed GitHub organization.
|
err = """You are not a member of an allowed GitHub organization.
|
||||||
Please contact your system administrator if you believe this is in error."""
|
Please contact your system administrator if you believe this is in error."""
|
||||||
return render_ologin_error('GitHub', err)
|
return render_ologin_error('GitHub', err)
|
||||||
|
@ -175,6 +185,8 @@ def github_oauth_callback():
|
||||||
# Find the e-mail address for the user: we will accept any email, but we prefer the primary
|
# Find the e-mail address for the user: we will accept any email, but we prefer the primary
|
||||||
get_email = client.get(github_login.email_endpoint(), params=token_param,
|
get_email = client.get(github_login.email_endpoint(), params=token_param,
|
||||||
headers=v3_media_type)
|
headers=v3_media_type)
|
||||||
|
if get_email.status_code / 100 != 2:
|
||||||
|
return render_ologin_error('GitHub')
|
||||||
|
|
||||||
found_email = None
|
found_email = None
|
||||||
for user_email in get_email.json():
|
for user_email in get_email.json():
|
||||||
|
@ -199,10 +211,13 @@ def github_oauth_callback():
|
||||||
@oauthlogin.route('/google/callback/attach', methods=['GET'])
|
@oauthlogin.route('/google/callback/attach', methods=['GET'])
|
||||||
@route_show_if(features.GOOGLE_LOGIN)
|
@route_show_if(features.GOOGLE_LOGIN)
|
||||||
@require_session_login
|
@require_session_login
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def google_oauth_attach():
|
def google_oauth_attach():
|
||||||
code = request.args.get('code')
|
code = request.args.get('code')
|
||||||
token = google_login.exchange_code_for_token(app.config, client, code,
|
token = google_login.exchange_code_for_token(app.config, client, code,
|
||||||
redirect_suffix='/attach', form_encode=True)
|
redirect_suffix='/attach', form_encode=True)
|
||||||
|
if token is None:
|
||||||
|
return render_ologin_error('Google')
|
||||||
|
|
||||||
user_data = get_user(google_login, token)
|
user_data = get_user(google_login, token)
|
||||||
if not user_data or not user_data.get('id', None):
|
if not user_data or not user_data.get('id', None):
|
||||||
|
@ -236,9 +251,13 @@ def google_oauth_attach():
|
||||||
@oauthlogin.route('/github/callback/attach', methods=['GET'])
|
@oauthlogin.route('/github/callback/attach', methods=['GET'])
|
||||||
@route_show_if(features.GITHUB_LOGIN)
|
@route_show_if(features.GITHUB_LOGIN)
|
||||||
@require_session_login
|
@require_session_login
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def github_oauth_attach():
|
def github_oauth_attach():
|
||||||
code = request.args.get('code')
|
code = request.args.get('code')
|
||||||
token = github_login.exchange_code_for_token(app.config, client, code)
|
token = github_login.exchange_code_for_token(app.config, client, code)
|
||||||
|
if token is None:
|
||||||
|
return render_ologin_error('GitHub')
|
||||||
|
|
||||||
user_data = get_user(github_login, token)
|
user_data = get_user(github_login, token)
|
||||||
if not user_data:
|
if not user_data:
|
||||||
return render_ologin_error('GitHub')
|
return render_ologin_error('GitHub')
|
||||||
|
@ -276,6 +295,7 @@ def decode_user_jwt(token, oidc_provider):
|
||||||
|
|
||||||
@oauthlogin.route('/dex/callback', methods=['GET', 'POST'])
|
@oauthlogin.route('/dex/callback', methods=['GET', 'POST'])
|
||||||
@route_show_if(features.DEX_LOGIN)
|
@route_show_if(features.DEX_LOGIN)
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def dex_oauth_callback():
|
def dex_oauth_callback():
|
||||||
error = request.values.get('error', None)
|
error = request.values.get('error', None)
|
||||||
if error:
|
if error:
|
||||||
|
@ -287,6 +307,8 @@ def dex_oauth_callback():
|
||||||
|
|
||||||
token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True,
|
token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True,
|
||||||
form_encode=True)
|
form_encode=True)
|
||||||
|
if token is None:
|
||||||
|
return render_ologin_error(dex_login.public_title)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = decode_user_jwt(token, dex_login)
|
payload = decode_user_jwt(token, dex_login)
|
||||||
|
@ -318,11 +340,12 @@ def dex_oauth_callback():
|
||||||
@oauthlogin.route('/dex/callback/attach', methods=['GET', 'POST'])
|
@oauthlogin.route('/dex/callback/attach', methods=['GET', 'POST'])
|
||||||
@route_show_if(features.DEX_LOGIN)
|
@route_show_if(features.DEX_LOGIN)
|
||||||
@require_session_login
|
@require_session_login
|
||||||
|
@oauthlogin_csrf_protect
|
||||||
def dex_oauth_attach():
|
def dex_oauth_attach():
|
||||||
code = request.args.get('code')
|
code = request.args.get('code')
|
||||||
token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach',
|
token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach',
|
||||||
client_auth=True, form_encode=True)
|
client_auth=True, form_encode=True)
|
||||||
if not token:
|
if token is None:
|
||||||
return render_ologin_error(dex_login.public_title)
|
return render_ologin_error(dex_login.public_title)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -346,4 +369,3 @@ def dex_oauth_attach():
|
||||||
return render_ologin_error(dex_login.public_title, err)
|
return render_ologin_error(dex_login.public_title, err)
|
||||||
|
|
||||||
return redirect(url_for('web.user_view', path=user_obj.username, tab='external'))
|
return redirect(url_for('web.user_view', path=user_obj.username, tab='external'))
|
||||||
|
|
||||||
|
|
|
@ -494,7 +494,7 @@ def oauth_local_handler():
|
||||||
|
|
||||||
|
|
||||||
@web.route('/oauth/denyapp', methods=['POST'])
|
@web.route('/oauth/denyapp', methods=['POST'])
|
||||||
@csrf_protect
|
@csrf_protect()
|
||||||
def deny_application():
|
def deny_application():
|
||||||
if not current_user.is_authenticated:
|
if not current_user.is_authenticated:
|
||||||
abort(401)
|
abort(401)
|
||||||
|
|
|
@ -215,7 +215,8 @@ quayApp.config(['$routeProvider', '$locationProvider', 'pages', function($routeP
|
||||||
// 404/403
|
// 404/403
|
||||||
.route('/:catchall', 'error-view')
|
.route('/:catchall', 'error-view')
|
||||||
.route('/:catch/:all', 'error-view')
|
.route('/:catch/:all', 'error-view')
|
||||||
.route('/:catch/:all/:things', 'error-view');
|
.route('/:catch/:all/:things', 'error-view')
|
||||||
|
.route('/:catch/:all/:things/:here', 'error-view');
|
||||||
}]);
|
}]);
|
||||||
|
|
||||||
// Configure compile provider to add additional URL prefixes to the sanitization list. We use
|
// Configure compile provider to add additional URL prefixes to the sanitization list. We use
|
||||||
|
|
|
@ -21,19 +21,21 @@ angular.module('quay').directive('externalLoginButton', function () {
|
||||||
|
|
||||||
$scope.startSignin = function() {
|
$scope.startSignin = function() {
|
||||||
$scope.signInStarted({'service': $scope.provider});
|
$scope.signInStarted({'service': $scope.provider});
|
||||||
|
ApiService.generateExternalLoginToken().then(function(data) {
|
||||||
|
var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login');
|
||||||
|
url = url + '&state=' + encodeURIComponent(data['token']);
|
||||||
|
|
||||||
var url = ExternalLoginService.getLoginUrl($scope.provider, $scope.action || 'login');
|
// Save the redirect URL in a cookie so that we can redirect back after the service returns to us.
|
||||||
|
var redirectURL = $scope.redirectUrl || window.location.toString();
|
||||||
|
CookieService.putPermanent('quay.redirectAfterLoad', redirectURL);
|
||||||
|
|
||||||
// Save the redirect URL in a cookie so that we can redirect back after the service returns to us.
|
// Needed to ensure that UI work done by the started callback is finished before the location
|
||||||
var redirectURL = $scope.redirectUrl || window.location.toString();
|
// changes.
|
||||||
CookieService.putPermanent('quay.redirectAfterLoad', redirectURL);
|
$scope.signingIn = true;
|
||||||
|
$timeout(function() {
|
||||||
// Needed to ensure that UI work done by the started callback is finished before the location
|
document.location = url;
|
||||||
// changes.
|
}, 250);
|
||||||
$scope.signingIn = true;
|
}, ApiService.errorDisplay('Could not perform sign in'));
|
||||||
$timeout(function() {
|
|
||||||
document.location = url;
|
|
||||||
}, 250);
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -9,14 +9,6 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features'
|
||||||
var serviceInfo = externalLoginService.getProvider(service);
|
var serviceInfo = externalLoginService.getProvider(service);
|
||||||
if (!serviceInfo) { return ''; }
|
if (!serviceInfo) { return ''; }
|
||||||
|
|
||||||
var stateClause = '';
|
|
||||||
|
|
||||||
if (Config.MIXPANEL_KEY && window.mixpanel) {
|
|
||||||
if (mixpanel.get_distinct_id !== undefined) {
|
|
||||||
stateClause = "&state=" + encodeURIComponent(mixpanel.get_distinct_id());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var loginUrl = KeyService.getConfiguration(serviceInfo.key, 'AUTHORIZE_ENDPOINT');
|
var loginUrl = KeyService.getConfiguration(serviceInfo.key, 'AUTHORIZE_ENDPOINT');
|
||||||
var clientId = KeyService.getConfiguration(serviceInfo.key, 'CLIENT_ID');
|
var clientId = KeyService.getConfiguration(serviceInfo.key, 'CLIENT_ID');
|
||||||
|
|
||||||
|
@ -28,8 +20,7 @@ angular.module('quay').factory('ExternalLoginService', ['KeyService', 'Features'
|
||||||
}
|
}
|
||||||
|
|
||||||
var url = loginUrl + 'client_id=' + clientId + '&scope=' + scope + '&redirect_uri=' +
|
var url = loginUrl + 'client_id=' + clientId + '&scope=' + scope + '&redirect_uri=' +
|
||||||
redirectUri + stateClause;
|
redirectUri;
|
||||||
|
|
||||||
return url;
|
return url;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -124,9 +124,11 @@ class ApiTestCase(unittest.TestCase):
|
||||||
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
query[CSRF_TOKEN_KEY] = CSRF_TOKEN
|
||||||
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
|
return urlunparse(list(parts[0:4]) + [urlencode(query)] + list(parts[5:]))
|
||||||
|
|
||||||
def url_for(self, resource_name, params={}):
|
def url_for(self, resource_name, params=None, skip_csrf=False):
|
||||||
|
params = params or {}
|
||||||
url = api.url_for(resource_name, **params)
|
url = api.url_for(resource_name, **params)
|
||||||
url = ApiTestCase._add_csrf(url)
|
if not skip_csrf:
|
||||||
|
url = ApiTestCase._add_csrf(url)
|
||||||
return url
|
return url
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
@ -211,8 +213,8 @@ class ApiTestCase(unittest.TestCase):
|
||||||
return parsed
|
return parsed
|
||||||
|
|
||||||
def putJsonResponse(self, resource_name, params={}, data={},
|
def putJsonResponse(self, resource_name, params={}, data={},
|
||||||
expected_code=200):
|
expected_code=200, skip_csrf=False):
|
||||||
rv = self.app.put(self.url_for(resource_name, params),
|
rv = self.app.put(self.url_for(resource_name, params, skip_csrf),
|
||||||
data=py_json.dumps(data),
|
data=py_json.dumps(data),
|
||||||
headers={"Content-Type": "application/json"})
|
headers={"Content-Type": "application/json"})
|
||||||
|
|
||||||
|
@ -246,15 +248,35 @@ class TestCSRFFailure(ApiTestCase):
|
||||||
self.login(READ_ACCESS_USER)
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
# Make sure a simple post call succeeds.
|
# Make sure a simple post call succeeds.
|
||||||
self.putJsonResponse(User,
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'))
|
||||||
data=dict(password='newpasswordiscool'))
|
|
||||||
|
|
||||||
# Change the session's CSRF token.
|
# Change the session's CSRF token.
|
||||||
self.setCsrfToken('someinvalidtoken')
|
self.setCsrfToken('someinvalidtoken')
|
||||||
|
|
||||||
# Verify that the call now fails.
|
# Verify that the call now fails.
|
||||||
self.putJsonResponse(User,
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
|
||||||
data=dict(password='newpasswordiscool'),
|
|
||||||
|
def test_csrf_failure_empty_token(self):
|
||||||
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
|
# Change the session's CSRF token to be empty.
|
||||||
|
self.setCsrfToken('')
|
||||||
|
|
||||||
|
# Verify that the call now fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), expected_code=403)
|
||||||
|
|
||||||
|
def test_csrf_failure_missing_token(self):
|
||||||
|
self.login(READ_ACCESS_USER)
|
||||||
|
|
||||||
|
# Make sure a simple post call without a token at all fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
|
||||||
|
expected_code=403)
|
||||||
|
|
||||||
|
# Change the session's CSRF token to be empty.
|
||||||
|
self.setCsrfToken('')
|
||||||
|
|
||||||
|
# Verify that the call still fails.
|
||||||
|
self.putJsonResponse(User, data=dict(password='newpasswordiscool'), skip_csrf=True,
|
||||||
expected_code=403)
|
expected_code=403)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import base64
|
||||||
from urllib import urlencode
|
from urllib import urlencode
|
||||||
from urlparse import urlparse, urlunparse, parse_qs
|
from urlparse import urlparse, urlunparse, parse_qs
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from httmock import urlmatch, HTTMock
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
|
@ -22,13 +23,25 @@ from endpoints import keyserver
|
||||||
from endpoints.api import api, api_bp
|
from endpoints.api import api, api_bp
|
||||||
from endpoints.api.user import Signin
|
from endpoints.api.user import Signin
|
||||||
from endpoints.keyserver import jwk_with_kid
|
from endpoints.keyserver import jwk_with_kid
|
||||||
|
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
|
||||||
from endpoints.web import web as web_bp
|
from endpoints.web import web as web_bp
|
||||||
|
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
|
||||||
from initdb import setup_database_for_testing, finished_database_for_testing
|
from initdb import setup_database_for_testing, finished_database_for_testing
|
||||||
from test.helpers import assert_action_logged
|
from test.helpers import assert_action_logged
|
||||||
|
|
||||||
|
try:
|
||||||
|
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth')
|
||||||
|
except ValueError:
|
||||||
|
# This blueprint was already registered
|
||||||
|
pass
|
||||||
|
|
||||||
try:
|
try:
|
||||||
app.register_blueprint(web_bp, url_prefix='')
|
app.register_blueprint(web_bp, url_prefix='')
|
||||||
|
except ValueError:
|
||||||
|
# This blueprint was already registered
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
app.register_blueprint(keyserver.key_server, url_prefix='')
|
app.register_blueprint(keyserver.key_server, url_prefix='')
|
||||||
except ValueError:
|
except ValueError:
|
||||||
# This blueprint was already registered
|
# This blueprint was already registered
|
||||||
|
@ -69,6 +82,7 @@ class EndpointTestCase(unittest.TestCase):
|
||||||
def setCsrfToken(self, token):
|
def setCsrfToken(self, token):
|
||||||
with self.app.session_transaction() as sess:
|
with self.app.session_transaction() as sess:
|
||||||
sess[CSRF_TOKEN_KEY] = token
|
sess[CSRF_TOKEN_KEY] = token
|
||||||
|
sess[OAUTH_CSRF_TOKEN_NAME] = 'someoauthtoken'
|
||||||
|
|
||||||
def getResponse(self, resource_name, expected_code=200, **kwargs):
|
def getResponse(self, resource_name, expected_code=200, **kwargs):
|
||||||
rv = self.app.get(url_for(resource_name, **kwargs))
|
rv = self.app.get(url_for(resource_name, **kwargs))
|
||||||
|
@ -108,6 +122,140 @@ class EndpointTestCase(unittest.TestCase):
|
||||||
self.assertEquals(rv.status_code, 200)
|
self.assertEquals(rv.status_code, 200)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthLoginTestCase(EndpointTestCase):
|
||||||
|
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
|
||||||
|
new_username):
|
||||||
|
# Test callback.
|
||||||
|
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
|
||||||
|
|
||||||
|
# Delete the created user.
|
||||||
|
model.user.delete_user(created, [])
|
||||||
|
|
||||||
|
# Test attach.
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
|
||||||
|
|
||||||
|
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
|
||||||
|
# No CSRF.
|
||||||
|
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
|
||||||
|
|
||||||
|
# Invalid CSRF.
|
||||||
|
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
|
||||||
|
|
||||||
|
# Valid CSRF, invalid code.
|
||||||
|
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||||
|
code='invalidcode', expected_code=400)
|
||||||
|
|
||||||
|
# Valid CSRF, valid code.
|
||||||
|
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||||
|
code='somecode', expected_code=302)
|
||||||
|
|
||||||
|
# Ensure the user was added/modified.
|
||||||
|
found_user = model.user.get_user(username)
|
||||||
|
self.assertIsNotNone(found_user)
|
||||||
|
|
||||||
|
federated_login = model.user.lookup_federated_login(found_user, service_name)
|
||||||
|
self.assertIsNotNone(federated_login)
|
||||||
|
self.assertEquals(federated_login.service_ident, service_ident)
|
||||||
|
return found_user
|
||||||
|
|
||||||
|
def test_google_oauth(self):
|
||||||
|
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
|
||||||
|
def account_handler(_, request):
|
||||||
|
if request.body.find("code=somecode") > 0:
|
||||||
|
content = {'access_token': 'someaccesstoken'}
|
||||||
|
return py_json.dumps(content)
|
||||||
|
else:
|
||||||
|
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
|
||||||
|
def user_handler(_, __):
|
||||||
|
content = {
|
||||||
|
'id': 'someid',
|
||||||
|
'email': 'someemail@example.com',
|
||||||
|
'verified_email': True,
|
||||||
|
}
|
||||||
|
return py_json.dumps(content)
|
||||||
|
|
||||||
|
with HTTMock(account_handler, user_handler):
|
||||||
|
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
|
||||||
|
'someid', 'someemail')
|
||||||
|
|
||||||
|
def test_github_oauth(self):
|
||||||
|
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
|
||||||
|
def account_handler(url, _):
|
||||||
|
if url.query.find("code=somecode") > 0:
|
||||||
|
content = {'access_token': 'someaccesstoken'}
|
||||||
|
return py_json.dumps(content)
|
||||||
|
else:
|
||||||
|
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'github.com', path='/api/v3/user')
|
||||||
|
def user_handler(_, __):
|
||||||
|
content = {
|
||||||
|
'id': 'someid',
|
||||||
|
'login': 'someusername'
|
||||||
|
}
|
||||||
|
return py_json.dumps(content)
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
|
||||||
|
def email_handler(_, __):
|
||||||
|
content = [{
|
||||||
|
'email': 'someemail@example.com',
|
||||||
|
'verified': True,
|
||||||
|
'primary': True,
|
||||||
|
}]
|
||||||
|
return py_json.dumps(content)
|
||||||
|
|
||||||
|
with HTTMock(account_handler, email_handler, user_handler):
|
||||||
|
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
|
||||||
|
'someid', 'someusername')
|
||||||
|
|
||||||
|
def test_dex_oauth(self):
|
||||||
|
# TODO(jschorr): Add tests for invalid and expired keys.
|
||||||
|
|
||||||
|
# Generate a public/private key pair for the OIDC transaction.
|
||||||
|
private_key = RSA.generate(2048)
|
||||||
|
jwk = RSAKey(key=private_key.publickey()).serialize()
|
||||||
|
token = jwt.encode({
|
||||||
|
'iss': 'https://oidcserver/',
|
||||||
|
'aud': 'someclientid',
|
||||||
|
'sub': 'someid',
|
||||||
|
'exp': int(time.time()) + 60,
|
||||||
|
'iat': int(time.time()),
|
||||||
|
'nbf': int(time.time()),
|
||||||
|
'email': 'someemail@example.com',
|
||||||
|
'email_verified': True,
|
||||||
|
}, private_key.exportKey('PEM'), 'RS256')
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration')
|
||||||
|
def wellknown_handler(url, _):
|
||||||
|
return py_json.dumps({
|
||||||
|
'authorization_endpoint': 'http://oidcserver/auth',
|
||||||
|
'token_endpoint': 'http://oidcserver/token',
|
||||||
|
'jwks_uri': 'http://oidcserver/keys',
|
||||||
|
})
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'oidcserver', path='/token')
|
||||||
|
def account_handler(url, request):
|
||||||
|
if request.body.find("code=somecode") > 0:
|
||||||
|
return py_json.dumps({
|
||||||
|
'access_token': token,
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||||
|
|
||||||
|
@urlmatch(netloc=r'oidcserver', path='/keys')
|
||||||
|
def keys_handler(_, __):
|
||||||
|
return py_json.dumps({
|
||||||
|
"keys": [jwk],
|
||||||
|
})
|
||||||
|
|
||||||
|
with HTTMock(wellknown_handler, account_handler, keys_handler):
|
||||||
|
self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex',
|
||||||
|
'someid', 'someemail')
|
||||||
|
|
||||||
|
|
||||||
class WebEndpointTestCase(EndpointTestCase):
|
class WebEndpointTestCase(EndpointTestCase):
|
||||||
def test_index(self):
|
def test_index(self):
|
||||||
self.getResponse('web.index')
|
self.getResponse('web.index')
|
||||||
|
|
|
@ -75,3 +75,12 @@ class TestConfig(DefaultConfig):
|
||||||
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
|
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
|
||||||
|
|
||||||
PROMETHEUS_AGGREGATOR_URL = None
|
PROMETHEUS_AGGREGATOR_URL = None
|
||||||
|
|
||||||
|
FEATURE_GITHUB_LOGIN = True
|
||||||
|
FEATURE_GOOGLE_LOGIN = True
|
||||||
|
FEATURE_DEX_LOGIN = True
|
||||||
|
|
||||||
|
DEX_LOGIN_CONFIG = {
|
||||||
|
'CLIENT_ID': 'someclientid',
|
||||||
|
'OIDC_SERVER': 'https://oidcserver/',
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,11 @@ import logging
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from cachetools import TTLCache
|
from cachetools import TTLCache
|
||||||
|
from cachetools.func import lru_cache
|
||||||
|
|
||||||
|
from cryptography.hazmat.backends import default_backend
|
||||||
|
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||||
|
|
||||||
from jwkest.jwk import KEYS
|
from jwkest.jwk import KEYS
|
||||||
from util import slash_join
|
from util import slash_join
|
||||||
|
|
||||||
|
@ -64,12 +69,14 @@ class OAuthConfig(object):
|
||||||
else:
|
else:
|
||||||
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
|
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
|
||||||
|
|
||||||
|
if get_access_token.status_code / 100 != 2:
|
||||||
|
return None
|
||||||
|
|
||||||
json_data = get_access_token.json()
|
json_data = get_access_token.json()
|
||||||
if not json_data:
|
if not json_data:
|
||||||
return ''
|
return None
|
||||||
|
|
||||||
token = json_data.get('access_token', '')
|
return json_data.get('access_token', None)
|
||||||
return token
|
|
||||||
|
|
||||||
|
|
||||||
class GithubOAuthConfig(OAuthConfig):
|
class GithubOAuthConfig(OAuthConfig):
|
||||||
|
@ -265,11 +272,15 @@ class OIDCConfig(OAuthConfig):
|
||||||
super(OIDCConfig, self).__init__(config, key_name)
|
super(OIDCConfig, self).__init__(config, key_name)
|
||||||
|
|
||||||
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
|
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
|
||||||
self._oidc_config = {}
|
self._config = config
|
||||||
self._http_client = config['HTTPCLIENT']
|
self._http_client = config['HTTPCLIENT']
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def _oidc_config(self):
|
||||||
if self.config.get('OIDC_SERVER'):
|
if self.config.get('OIDC_SERVER'):
|
||||||
self._load_via_discovery(config.get('DEBUGGING', False))
|
return self._load_via_discovery(self._config.get('DEBUGGING', False))
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
|
||||||
def _load_via_discovery(self, is_debugging):
|
def _load_via_discovery(self, is_debugging):
|
||||||
oidc_server = self.config['OIDC_SERVER']
|
oidc_server = self.config['OIDC_SERVER']
|
||||||
|
@ -283,16 +294,16 @@ class OIDCConfig(OAuthConfig):
|
||||||
raise Exception("Could not load OIDC discovery information")
|
raise Exception("Could not load OIDC discovery information")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._oidc_config = json.loads(discovery.text)
|
return json.loads(discovery.text)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
||||||
raise Exception("Could not parse OIDC discovery information")
|
raise Exception("Could not parse OIDC discovery information")
|
||||||
|
|
||||||
def authorize_endpoint(self):
|
def authorize_endpoint(self):
|
||||||
return self._oidc_config.get('authorization_endpoint', '') + '?'
|
return self._oidc_config().get('authorization_endpoint', '') + '?'
|
||||||
|
|
||||||
def token_endpoint(self):
|
def token_endpoint(self):
|
||||||
return self._oidc_config.get('token_endpoint')
|
return self._oidc_config().get('token_endpoint')
|
||||||
|
|
||||||
def user_endpoint(self):
|
def user_endpoint(self):
|
||||||
return None
|
return None
|
||||||
|
@ -322,9 +333,9 @@ class OIDCConfig(OAuthConfig):
|
||||||
# a random key chose to be stored in the cache, and could be anything.
|
# a random key chose to be stored in the cache, and could be anything.
|
||||||
return self._public_key_cache[None]
|
return self._public_key_cache[None]
|
||||||
|
|
||||||
def _get_public_key(self):
|
def _get_public_key(self, _):
|
||||||
""" Retrieves the public key for this handler. """
|
""" Retrieves the public key for this handler. """
|
||||||
keys_url = self._oidc_config['jwks_uri']
|
keys_url = self._oidc_config()['jwks_uri']
|
||||||
|
|
||||||
keys = KEYS()
|
keys = KEYS()
|
||||||
keys.load_from_url(keys_url)
|
keys.load_from_url(keys_url)
|
||||||
|
@ -334,7 +345,10 @@ class OIDCConfig(OAuthConfig):
|
||||||
|
|
||||||
rsa_key = list(keys)[0]
|
rsa_key = list(keys)[0]
|
||||||
rsa_key.deserialize()
|
rsa_key.deserialize()
|
||||||
return rsa_key.key.exportKey('PEM')
|
|
||||||
|
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
|
||||||
|
# issues.
|
||||||
|
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
|
||||||
|
|
||||||
|
|
||||||
class DexOAuthConfig(OIDCConfig):
|
class DexOAuthConfig(OIDCConfig):
|
||||||
|
|
Reference in a new issue