Merge pull request #1457 from coreos-inc/xauth
Add support for direct granting of OAuth tokens and add tests
This commit is contained in:
commit
a85c3ebff7
5 changed files with 216 additions and 18 deletions
57
auth/auth.py
57
auth/auth.py
|
@ -72,24 +72,37 @@ def _validate_and_apply_oauth_token(token):
|
|||
identity_changed.send(app, identity=new_identity)
|
||||
|
||||
|
||||
def _process_basic_auth(auth):
|
||||
def _parse_basic_auth_header(auth):
|
||||
normalized = [part.strip() for part in auth.split(' ') if part]
|
||||
if normalized[0].lower() != 'basic' or len(normalized) != 2:
|
||||
logger.debug('Invalid basic auth format.')
|
||||
return
|
||||
return None
|
||||
|
||||
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
|
||||
logger.debug('Found basic auth header: %s', auth)
|
||||
try:
|
||||
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
|
||||
except TypeError:
|
||||
logger.exception('Exception when parsing basic auth header')
|
||||
return None
|
||||
|
||||
if len(credentials) != 2:
|
||||
logger.debug('Invalid basic auth credential format.')
|
||||
return None
|
||||
|
||||
elif credentials[0] == '$token':
|
||||
return credentials
|
||||
|
||||
|
||||
def _process_basic_auth(auth):
|
||||
credentials = _parse_basic_auth_header(auth)
|
||||
if credentials is None:
|
||||
return
|
||||
|
||||
if credentials[0] == '$token':
|
||||
# Use as token auth
|
||||
try:
|
||||
token = model.token.load_token_data(credentials[1])
|
||||
logger.debug('Successfully validated token: %s', credentials[1])
|
||||
set_validated_token(token)
|
||||
|
||||
identity_changed.send(app, identity=Identity(token.code, 'token'))
|
||||
return
|
||||
|
||||
|
@ -117,7 +130,6 @@ def _process_basic_auth(auth):
|
|||
else:
|
||||
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
|
||||
basic_auth=True)
|
||||
|
||||
if authenticated:
|
||||
logger.debug('Successfully validated user: %s', authenticated.username)
|
||||
set_authenticated_user(authenticated)
|
||||
|
@ -130,6 +142,23 @@ def _process_basic_auth(auth):
|
|||
logger.debug('Basic auth present but could not be validated.')
|
||||
|
||||
|
||||
def has_basic_auth(username):
|
||||
auth = request.headers.get('authorization', '')
|
||||
if not auth:
|
||||
return False
|
||||
|
||||
credentials = _parse_basic_auth_header(auth)
|
||||
if not credentials:
|
||||
return False
|
||||
|
||||
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
|
||||
basic_auth=True)
|
||||
if not authenticated:
|
||||
return False
|
||||
|
||||
return authenticated.username == username
|
||||
|
||||
|
||||
def generate_signed_token(grants, user_context):
|
||||
ser = SecureCookieSessionInterface().get_signing_serializer(app)
|
||||
data_to_sign = {
|
||||
|
@ -209,6 +238,22 @@ def process_auth(func):
|
|||
return wrapper
|
||||
|
||||
|
||||
def process_auth_or_cookie(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
auth = request.headers.get('authorization', '')
|
||||
|
||||
if auth:
|
||||
logger.debug('Validating auth header: %s', auth)
|
||||
_process_basic_auth(auth)
|
||||
else:
|
||||
logger.debug('No auth header.')
|
||||
_load_user_from_cookie()
|
||||
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
|
||||
def require_session_login(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
|
|
|
@ -348,3 +348,7 @@ class DefaultConfig(object):
|
|||
|
||||
# Number of minutes between expiration refresh in minutes
|
||||
INSTANCE_SERVICE_KEY_REFRESH = 60
|
||||
|
||||
# The whitelist of client IDs for OAuth applications that allow for direct login.
|
||||
DIRECT_OAUTH_CLIENTID_WHITELIST = []
|
||||
|
||||
|
|
|
@ -168,22 +168,23 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
err = 'unsupported_response_type'
|
||||
return self._make_redirect_error_response(redirect_uri, err)
|
||||
|
||||
# Check redirect URI
|
||||
# Check for a valid client ID.
|
||||
is_valid_client_id = self.validate_client_id(client_id)
|
||||
if not is_valid_client_id:
|
||||
err = 'unauthorized_client'
|
||||
return self._make_redirect_error_response(redirect_uri, err)
|
||||
|
||||
# Check for a valid redirect URI.
|
||||
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
|
||||
if not is_valid_redirect_uri:
|
||||
return self._invalid_redirect_uri_response()
|
||||
|
||||
# Check conditions
|
||||
is_valid_client_id = self.validate_client_id(client_id)
|
||||
is_valid_access = self.validate_access()
|
||||
scope = params.get('scope', '')
|
||||
are_valid_scopes = self.validate_scope(client_id, scope)
|
||||
|
||||
# Return proper error responses on invalid conditions
|
||||
if not is_valid_client_id:
|
||||
err = 'unauthorized_client'
|
||||
return self._make_redirect_error_response(redirect_uri, err)
|
||||
|
||||
if not is_valid_access:
|
||||
err = 'access_denied'
|
||||
return self._make_redirect_error_response(redirect_uri, err)
|
||||
|
|
|
@ -12,10 +12,11 @@ import features
|
|||
|
||||
from app import app, billing as stripe, build_logs, avatar, signer, log_archive, config_provider
|
||||
from auth import scopes
|
||||
from auth.auth import require_session_login, process_oauth
|
||||
from auth.auth import require_session_login, process_oauth, has_basic_auth, process_auth_or_cookie
|
||||
from auth.permissions import (AdministerOrganizationPermission, ReadRepositoryPermission,
|
||||
SuperUserPermission, AdministerRepositoryPermission,
|
||||
ModifyRepositoryPermission)
|
||||
from auth.auth_context import get_authenticated_user
|
||||
from buildtrigger.basehandler import BuildTriggerHandler
|
||||
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
|
||||
from buildtrigger.customhandler import CustomBuildTrigger
|
||||
|
@ -452,21 +453,27 @@ def build_status_badge(namespace_name, repo_name):
|
|||
|
||||
class FlaskAuthorizationProvider(model.oauth.DatabaseAuthorizationProvider):
|
||||
def get_authorized_user(self):
|
||||
return current_user.db_user()
|
||||
return get_authenticated_user()
|
||||
|
||||
def _make_response(self, body='', headers=None, status_code=200):
|
||||
return make_response(body, status_code, headers)
|
||||
|
||||
|
||||
@web.route('/oauth/authorizeapp', methods=['POST'])
|
||||
@csrf_protect
|
||||
@process_auth_or_cookie
|
||||
def authorize_application():
|
||||
if not current_user.is_authenticated:
|
||||
# Check for an authenticated user.
|
||||
if not get_authenticated_user():
|
||||
abort(401)
|
||||
return
|
||||
|
||||
provider = FlaskAuthorizationProvider()
|
||||
# If direct OAuth is not enabled or the user is not directly authed, verify CSRF.
|
||||
client_id = request.form.get('client_id', None)
|
||||
whitelist = app.config.get('DIRECT_OAUTH_CLIENTID_WHITELIST', [])
|
||||
if client_id not in whitelist or not has_basic_auth(get_authenticated_user().username):
|
||||
verify_csrf()
|
||||
|
||||
provider = FlaskAuthorizationProvider()
|
||||
redirect_uri = request.form.get('redirect_uri', None)
|
||||
scope = request.form.get('scope', None)
|
||||
|
||||
|
@ -474,7 +481,6 @@ def authorize_application():
|
|||
return provider.get_token_response('token', client_id, redirect_uri, scope=scope)
|
||||
|
||||
|
||||
|
||||
@web.route(app.config['LOCAL_OAUTH_HANDLER'], methods=['GET'])
|
||||
def oauth_local_handler():
|
||||
if not current_user.is_authenticated:
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
import json as py_json
|
||||
import time
|
||||
import unittest
|
||||
import base64
|
||||
|
||||
from urllib import urlencode
|
||||
from urlparse import urlparse, urlunparse, parse_qs
|
||||
|
@ -86,6 +87,17 @@ class EndpointTestCase(unittest.TestCase):
|
|||
self.assertEquals(rv.status_code, expected_code)
|
||||
return rv.data
|
||||
|
||||
def postResponse(self, resource_name, headers=None, form=None, with_csrf=True, expected_code=200, **kwargs):
|
||||
headers = headers or {}
|
||||
form = form or {}
|
||||
url = url_for(resource_name, **kwargs)
|
||||
if with_csrf:
|
||||
url = EndpointTestCase._add_csrf(url)
|
||||
|
||||
rv = self.app.post(url, headers=headers, data=form)
|
||||
self.assertEquals(rv.status_code, expected_code)
|
||||
return rv
|
||||
|
||||
def login(self, username, password):
|
||||
rv = self.app.post(EndpointTestCase._add_csrf(api.url_for(Signin)),
|
||||
data=py_json.dumps(dict(username=username, password=password)),
|
||||
|
@ -191,6 +203,136 @@ class WebEndpointTestCase(EndpointTestCase):
|
|||
self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302)
|
||||
|
||||
|
||||
class OAuthTestCase(EndpointTestCase):
|
||||
def test_authorize_nologin(self):
|
||||
form = {
|
||||
'client_id': 'someclient',
|
||||
'redirect_uri': 'http://localhost:5000/foobar',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=401)
|
||||
|
||||
def test_authorize_invalidclient(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
form = {
|
||||
'client_id': 'someclient',
|
||||
'redirect_uri': 'http://localhost:5000/foobar',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||
self.assertEquals('http://localhost:5000/foobar?error=unauthorized_client', resp.headers['Location'])
|
||||
|
||||
def test_authorize_invalidscope(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'invalid:scope',
|
||||
}
|
||||
|
||||
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||
self.assertEquals('http://localhost:8000/o2c.html?error=invalid_scope', resp.headers['Location'])
|
||||
|
||||
def test_authorize_invalidredirecturi(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://some/invalid/uri',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=400)
|
||||
|
||||
def test_authorize_success(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
|
||||
self.assertTrue('access_token=' in resp.headers['Location'])
|
||||
|
||||
def test_authorize_nocsrf(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
self.postResponse('web.authorize_application', form=form, with_csrf=False, expected_code=403)
|
||||
|
||||
def test_authorize_nocsrf_withinvalidheader(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
headers = dict(authorization='Some random header')
|
||||
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||
|
||||
def test_authorize_nocsrf_withbadheader(self):
|
||||
self.login('devtable', 'password')
|
||||
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
|
||||
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||
|
||||
def test_authorize_nocsrf_correctheader(self):
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
# Try without the client id being in the whitelist.
|
||||
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
|
||||
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=403)
|
||||
|
||||
# Add the client ID to the whitelist and try again.
|
||||
app.config['DIRECT_OAUTH_CLIENTID_WHITELIST'] = ['deadbeef']
|
||||
|
||||
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
|
||||
resp = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=True, expected_code=302)
|
||||
self.assertTrue('access_token=' in resp.headers['Location'])
|
||||
|
||||
def test_authorize_nocsrf_ratelimiting(self):
|
||||
# Note: Defined in initdb.py
|
||||
form = {
|
||||
'client_id': 'deadbeef',
|
||||
'redirect_uri': 'http://localhost:8000/o2c.html',
|
||||
'scope': 'user:admin',
|
||||
}
|
||||
|
||||
# Try without the client id being in the whitelist a few times, making sure we eventually get rate limited.
|
||||
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
|
||||
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
|
||||
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=429)
|
||||
|
||||
|
||||
class KeyServerTestCase(EndpointTestCase):
|
||||
def _get_test_jwt_payload(self):
|
||||
return {
|
||||
|
|
Reference in a new issue