Merge pull request #1457 from coreos-inc/xauth
Add support for direct granting of OAuth tokens and add tests
This commit is contained in:
commit
a85c3ebff7
5 changed files with 216 additions and 18 deletions
57
auth/auth.py
57
auth/auth.py
|
@ -72,24 +72,37 @@ def _validate_and_apply_oauth_token(token):
|
||||||
identity_changed.send(app, identity=new_identity)
|
identity_changed.send(app, identity=new_identity)
|
||||||
|
|
||||||
|
|
||||||
def _process_basic_auth(auth):
|
def _parse_basic_auth_header(auth):
|
||||||
normalized = [part.strip() for part in auth.split(' ') if part]
|
normalized = [part.strip() for part in auth.split(' ') if part]
|
||||||
if normalized[0].lower() != 'basic' or len(normalized) != 2:
|
if normalized[0].lower() != 'basic' or len(normalized) != 2:
|
||||||
logger.debug('Invalid basic auth format.')
|
logger.debug('Invalid basic auth format.')
|
||||||
return
|
return None
|
||||||
|
|
||||||
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
|
logger.debug('Found basic auth header: %s', auth)
|
||||||
|
try:
|
||||||
|
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
|
||||||
|
except TypeError:
|
||||||
|
logger.exception('Exception when parsing basic auth header')
|
||||||
|
return None
|
||||||
|
|
||||||
if len(credentials) != 2:
|
if len(credentials) != 2:
|
||||||
logger.debug('Invalid basic auth credential format.')
|
logger.debug('Invalid basic auth credential format.')
|
||||||
|
return None
|
||||||
|
|
||||||
elif credentials[0] == '$token':
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
|
def _process_basic_auth(auth):
|
||||||
|
credentials = _parse_basic_auth_header(auth)
|
||||||
|
if credentials is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if credentials[0] == '$token':
|
||||||
# Use as token auth
|
# Use as token auth
|
||||||
try:
|
try:
|
||||||
token = model.token.load_token_data(credentials[1])
|
token = model.token.load_token_data(credentials[1])
|
||||||
logger.debug('Successfully validated token: %s', credentials[1])
|
logger.debug('Successfully validated token: %s', credentials[1])
|
||||||
set_validated_token(token)
|
set_validated_token(token)
|
||||||
|
|
||||||
identity_changed.send(app, identity=Identity(token.code, 'token'))
|
identity_changed.send(app, identity=Identity(token.code, 'token'))
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -117,7 +130,6 @@ def _process_basic_auth(auth):
|
||||||
else:
|
else:
|
||||||
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
|
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
|
||||||
basic_auth=True)
|
basic_auth=True)
|
||||||
|
|
||||||
if authenticated:
|
if authenticated:
|
||||||
logger.debug('Successfully validated user: %s', authenticated.username)
|
logger.debug('Successfully validated user: %s', authenticated.username)
|
||||||
set_authenticated_user(authenticated)
|
set_authenticated_user(authenticated)
|
||||||
|
@ -130,6 +142,23 @@ def _process_basic_auth(auth):
|
||||||
logger.debug('Basic auth present but could not be validated.')
|
logger.debug('Basic auth present but could not be validated.')
|
||||||
|
|
||||||
|
|
||||||
|
def has_basic_auth(username):
|
||||||
|
auth = request.headers.get('authorization', '')
|
||||||
|
if not auth:
|
||||||
|
return False
|
||||||
|
|
||||||
|
credentials = _parse_basic_auth_header(auth)
|
||||||
|
if not credentials:
|
||||||
|
return False
|
||||||
|
|
||||||
|
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
|
||||||
|
basic_auth=True)
|
||||||
|
if not authenticated:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return authenticated.username == username
|
||||||
|
|
||||||
|
|
||||||
def generate_signed_token(grants, user_context):
|
def generate_signed_token(grants, user_context):
|
||||||
ser = SecureCookieSessionInterface().get_signing_serializer(app)
|
ser = SecureCookieSessionInterface().get_signing_serializer(app)
|
||||||
data_to_sign = {
|
data_to_sign = {
|
||||||
|
@ -209,6 +238,22 @@ def process_auth(func):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def process_auth_or_cookie(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
auth = request.headers.get('authorization', '')
|
||||||
|
|
||||||
|
if auth:
|
||||||
|
logger.debug('Validating auth header: %s', auth)
|
||||||
|
_process_basic_auth(auth)
|
||||||
|
else:
|
||||||
|
logger.debug('No auth header.')
|
||||||
|
_load_user_from_cookie()
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def require_session_login(func):
|
def require_session_login(func):
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
|
|
|
@ -348,3 +348,7 @@ class DefaultConfig(object):
|
||||||
|
|
||||||
# Number of minutes between expiration refresh in minutes
|
# Number of minutes between expiration refresh in minutes
|
||||||
INSTANCE_SERVICE_KEY_REFRESH = 60
|
INSTANCE_SERVICE_KEY_REFRESH = 60
|
||||||
|
|
||||||
|
# The whitelist of client IDs for OAuth applications that allow for direct login.
|
||||||
|
DIRECT_OAUTH_CLIENTID_WHITELIST = []
|
||||||
|
|
||||||
|
|
|
@ -168,22 +168,23 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
err = 'unsupported_response_type'
|
err = 'unsupported_response_type'
|
||||||
return self._make_redirect_error_response(redirect_uri, err)
|
return self._make_redirect_error_response(redirect_uri, err)
|
||||||
|
|
||||||
# Check redirect URI
|
# Check for a valid client ID.
|
||||||
|
is_valid_client_id = self.validate_client_id(client_id)
|
||||||
|
if not is_valid_client_id:
|
||||||
|
err = 'unauthorized_client'
|
||||||
|
return self._make_redirect_error_response(redirect_uri, err)
|
||||||
|
|
||||||
|
# Check for a valid redirect URI.
|
||||||
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
|
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
|
||||||
if not is_valid_redirect_uri:
|
if not is_valid_redirect_uri:
|
||||||
return self._invalid_redirect_uri_response()
|
return self._invalid_redirect_uri_response()
|
||||||
|
|
||||||
# Check conditions
|
# Check conditions
|
||||||
is_valid_client_id = self.validate_client_id(client_id)
|
|
||||||
is_valid_access = self.validate_access()
|
is_valid_access = self.validate_access()
|
||||||
scope = params.get('scope', '')
|
scope = params.get('scope', '')
|
||||||
are_valid_scopes = self.validate_scope(client_id, scope)
|
are_valid_scopes = self.validate_scope(client_id, scope)
|
||||||
|
|
||||||
# Return proper error responses on invalid conditions
|
# Return proper error responses on invalid conditions
|
||||||
if not is_valid_client_id:
|
|
||||||
err = 'unauthorized_client'
|
|
||||||
return self._make_redirect_error_response(redirect_uri, err)
|
|
||||||
|
|
||||||
if not is_valid_access:
|
if not is_valid_access:
|
||||||
err = 'access_denied'
|
err = 'access_denied'
|
||||||
return self._make_redirect_error_response(redirect_uri, err)
|
return self._make_redirect_error_response(redirect_uri, err)
|
||||||
|
|
|
@ -12,10 +12,11 @@ import features
|
||||||
|
|
||||||
from app import app, billing as stripe, build_logs, avatar, signer, log_archive, config_provider
|
from app import app, billing as stripe, build_logs, avatar, signer, log_archive, config_provider
|
||||||
from auth import scopes
|
from auth import scopes
|
||||||
from auth.auth import require_session_login, process_oauth
|
from auth.auth import require_session_login, process_oauth, has_basic_auth, process_auth_or_cookie
|
||||||
from auth.permissions import (AdministerOrganizationPermission, ReadRepositoryPermission,
|
from auth.permissions import (AdministerOrganizationPermission, ReadRepositoryPermission,
|
||||||
SuperUserPermission, AdministerRepositoryPermission,
|
SuperUserPermission, AdministerRepositoryPermission,
|
||||||
ModifyRepositoryPermission)
|
ModifyRepositoryPermission)
|
||||||
|
from auth.auth_context import get_authenticated_user
|
||||||
from buildtrigger.basehandler import BuildTriggerHandler
|
from buildtrigger.basehandler import BuildTriggerHandler
|
||||||
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
|
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
|
||||||
from buildtrigger.customhandler import CustomBuildTrigger
|
from buildtrigger.customhandler import CustomBuildTrigger
|
||||||
|
@ -452,21 +453,27 @@ def build_status_badge(namespace_name, repo_name):
|
||||||
|
|
||||||
class FlaskAuthorizationProvider(model.oauth.DatabaseAuthorizationProvider):
|
class FlaskAuthorizationProvider(model.oauth.DatabaseAuthorizationProvider):
|
||||||
def get_authorized_user(self):
|
def get_authorized_user(self):
|
||||||
return current_user.db_user()
|
return get_authenticated_user()
|
||||||
|
|
||||||
def _make_response(self, body='', headers=None, status_code=200):
|
def _make_response(self, body='', headers=None, status_code=200):
|
||||||
return make_response(body, status_code, headers)
|
return make_response(body, status_code, headers)
|
||||||
|
|
||||||
|
|
||||||
@web.route('/oauth/authorizeapp', methods=['POST'])
|
@web.route('/oauth/authorizeapp', methods=['POST'])
|
||||||
@csrf_protect
|
@process_auth_or_cookie
|
||||||
def authorize_application():
|
def authorize_application():
|
||||||
if not current_user.is_authenticated:
|
# Check for an authenticated user.
|
||||||
|
if not get_authenticated_user():
|
||||||
abort(401)
|
abort(401)
|
||||||
return
|
return
|
||||||
|
|
||||||
provider = FlaskAuthorizationProvider()
|
# If direct OAuth is not enabled or the user is not directly authed, verify CSRF.
|
||||||
client_id = request.form.get('client_id', None)
|
client_id = request.form.get('client_id', None)
|
||||||
|
whitelist = app.config.get('DIRECT_OAUTH_CLIENTID_WHITELIST', [])
|
||||||
|
if client_id not in whitelist or not has_basic_auth(get_authenticated_user().username):
|
||||||
|
verify_csrf()
|
||||||
|
|
||||||
|
provider = FlaskAuthorizationProvider()
|
||||||
redirect_uri = request.form.get('redirect_uri', None)
|
redirect_uri = request.form.get('redirect_uri', None)
|
||||||
scope = request.form.get('scope', None)
|
scope = request.form.get('scope', None)
|
||||||
|
|
||||||
|
@ -474,7 +481,6 @@ def authorize_application():
|
||||||
return provider.get_token_response('token', client_id, redirect_uri, scope=scope)
|
return provider.get_token_response('token', client_id, redirect_uri, scope=scope)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@web.route(app.config['LOCAL_OAUTH_HANDLER'], methods=['GET'])
|
@web.route(app.config['LOCAL_OAUTH_HANDLER'], methods=['GET'])
|
||||||
def oauth_local_handler():
|
def oauth_local_handler():
|
||||||
if not current_user.is_authenticated:
|
if not current_user.is_authenticated:
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
import json as py_json
|
import json as py_json
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
import base64
|
||||||
|
|
||||||
from urllib import urlencode
|
from urllib import urlencode
|
||||||
from urlparse import urlparse, urlunparse, parse_qs
|
from urlparse import urlparse, urlunparse, parse_qs
|
||||||
|
@ -86,6 +87,17 @@ class EndpointTestCase(unittest.TestCase):
|
||||||
self.assertEquals(rv.status_code, expected_code)
|
self.assertEquals(rv.status_code, expected_code)
|
||||||
return rv.data
|
return rv.data
|
||||||
|
|
||||||
|
def postResponse(self, resource_name, headers=None, form=None, with_csrf=True, expected_code=200, **kwargs):
|
||||||
|
headers = headers or {}
|
||||||
|
form = form or {}
|
||||||
|
url = url_for(resource_name, **kwargs)
|
||||||
|
if with_csrf:
|
||||||
|
url = EndpointTestCase._add_csrf(url)
|
||||||
|
|
||||||
|
rv = self.app.post(url, headers=headers, data=form)
|
||||||
|
self.assertEquals(rv.status_code, expected_code)
|
||||||
|
return rv
|
||||||
|
|
||||||
def login(self, username, password):
|
def login(self, username, password):
|
||||||
rv = self.app.post(EndpointTestCase._add_csrf(api.url_for(Signin)),
|
rv = self.app.post(EndpointTestCase._add_csrf(api.url_for(Signin)),
|
||||||
data=py_json.dumps(dict(username=username, password=password)),
|
data=py_json.dumps(dict(username=username, password=password)),
|
||||||
|
@ -191,6 +203,136 @@ class WebEndpointTestCase(EndpointTestCase):
|
||||||
self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302)
|
self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthTestCase(EndpointTestCase):
|
||||||
|
def test_authorize_nologin(self):
|
||||||
|
form = {
|
||||||
|
'client_id': 'someclient',
|
||||||
|
'redirect_uri': 'http://localhost:5000/foobar',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=401)
|
||||||
|
|
||||||
|
def test_authorize_invalidclient(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
form = {
|
||||||
|
'client_id': 'someclient',
|
||||||
|
'redirect_uri': 'http://localhost:5000/foobar',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||||
|
self.assertEquals('http://localhost:5000/foobar?error=unauthorized_client', resp.headers['Location'])
|
||||||
|
|
||||||
|
def test_authorize_invalidscope(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'invalid:scope',
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||||
|
self.assertEquals('http://localhost:8000/o2c.html?error=invalid_scope', resp.headers['Location'])
|
||||||
|
|
||||||
|
def test_authorize_invalidredirecturi(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://some/invalid/uri',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=400)
|
||||||
|
|
||||||
|
def test_authorize_success(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||||
|
self.assertTrue('access_token=' in resp.headers['Location'])
|
||||||
|
|
||||||
|
def test_authorize_nocsrf(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
self.postResponse('web.authorize_application', form=form, with_csrf=False, expected_code=403)
|
||||||
|
|
||||||
|
def test_authorize_nocsrf_withinvalidheader(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = dict(authorization='Some random header')
|
||||||
|
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||||
|
|
||||||
|
def test_authorize_nocsrf_withbadheader(self):
|
||||||
|
self.login('devtable', 'password')
|
||||||
|
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
|
||||||
|
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||||
|
|
||||||
|
def test_authorize_nocsrf_correctheader(self):
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try without the client id being in the whitelist.
|
||||||
|
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
|
||||||
|
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=403)
|
||||||
|
|
||||||
|
# Add the client ID to the whitelist and try again.
|
||||||
|
app.config['DIRECT_OAUTH_CLIENTID_WHITELIST'] = ['deadbeef']
|
||||||
|
|
||||||
|
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
|
||||||
|
resp = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=True, expected_code=302)
|
||||||
|
self.assertTrue('access_token=' in resp.headers['Location'])
|
||||||
|
|
||||||
|
def test_authorize_nocsrf_ratelimiting(self):
|
||||||
|
# Note: Defined in initdb.py
|
||||||
|
form = {
|
||||||
|
'client_id': 'deadbeef',
|
||||||
|
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||||
|
'scope': 'user:admin',
|
||||||
|
}
|
||||||
|
|
||||||
|
# Try without the client id being in the whitelist a few times, making sure we eventually get rate limited.
|
||||||
|
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
|
||||||
|
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||||
|
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=429)
|
||||||
|
|
||||||
|
|
||||||
class KeyServerTestCase(EndpointTestCase):
|
class KeyServerTestCase(EndpointTestCase):
|
||||||
def _get_test_jwt_payload(self):
|
def _get_test_jwt_payload(self):
|
||||||
return {
|
return {
|
||||||
|
|
Reference in a new issue