Merge pull request #1457 from coreos-inc/xauth

Add support for direct granting of OAuth tokens and add tests
This commit is contained in:
josephschorr 2016-06-01 12:07:12 -04:00
commit a85c3ebff7
5 changed files with 216 additions and 18 deletions

View file

@ -72,24 +72,37 @@ def _validate_and_apply_oauth_token(token):
identity_changed.send(app, identity=new_identity) identity_changed.send(app, identity=new_identity)
def _process_basic_auth(auth): def _parse_basic_auth_header(auth):
normalized = [part.strip() for part in auth.split(' ') if part] normalized = [part.strip() for part in auth.split(' ') if part]
if normalized[0].lower() != 'basic' or len(normalized) != 2: if normalized[0].lower() != 'basic' or len(normalized) != 2:
logger.debug('Invalid basic auth format.') logger.debug('Invalid basic auth format.')
return return None
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)] logger.debug('Found basic auth header: %s', auth)
try:
credentials = [part.decode('utf-8') for part in b64decode(normalized[1]).split(':', 1)]
except TypeError:
logger.exception('Exception when parsing basic auth header')
return None
if len(credentials) != 2: if len(credentials) != 2:
logger.debug('Invalid basic auth credential format.') logger.debug('Invalid basic auth credential format.')
return None
elif credentials[0] == '$token': return credentials
def _process_basic_auth(auth):
credentials = _parse_basic_auth_header(auth)
if credentials is None:
return
if credentials[0] == '$token':
# Use as token auth # Use as token auth
try: try:
token = model.token.load_token_data(credentials[1]) token = model.token.load_token_data(credentials[1])
logger.debug('Successfully validated token: %s', credentials[1]) logger.debug('Successfully validated token: %s', credentials[1])
set_validated_token(token) set_validated_token(token)
identity_changed.send(app, identity=Identity(token.code, 'token')) identity_changed.send(app, identity=Identity(token.code, 'token'))
return return
@ -117,7 +130,6 @@ def _process_basic_auth(auth):
else: else:
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1], (authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
basic_auth=True) basic_auth=True)
if authenticated: if authenticated:
logger.debug('Successfully validated user: %s', authenticated.username) logger.debug('Successfully validated user: %s', authenticated.username)
set_authenticated_user(authenticated) set_authenticated_user(authenticated)
@ -130,6 +142,23 @@ def _process_basic_auth(auth):
logger.debug('Basic auth present but could not be validated.') logger.debug('Basic auth present but could not be validated.')
def has_basic_auth(username):
auth = request.headers.get('authorization', '')
if not auth:
return False
credentials = _parse_basic_auth_header(auth)
if not credentials:
return False
(authenticated, _) = authentication.verify_and_link_user(credentials[0], credentials[1],
basic_auth=True)
if not authenticated:
return False
return authenticated.username == username
def generate_signed_token(grants, user_context): def generate_signed_token(grants, user_context):
ser = SecureCookieSessionInterface().get_signing_serializer(app) ser = SecureCookieSessionInterface().get_signing_serializer(app)
data_to_sign = { data_to_sign = {
@ -209,6 +238,22 @@ def process_auth(func):
return wrapper return wrapper
def process_auth_or_cookie(func):
@wraps(func)
def wrapper(*args, **kwargs):
auth = request.headers.get('authorization', '')
if auth:
logger.debug('Validating auth header: %s', auth)
_process_basic_auth(auth)
else:
logger.debug('No auth header.')
_load_user_from_cookie()
return func(*args, **kwargs)
return wrapper
def require_session_login(func): def require_session_login(func):
@wraps(func) @wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):

View file

@ -348,3 +348,7 @@ class DefaultConfig(object):
# Number of minutes between expiration refresh in minutes # Number of minutes between expiration refresh in minutes
INSTANCE_SERVICE_KEY_REFRESH = 60 INSTANCE_SERVICE_KEY_REFRESH = 60
# The whitelist of client IDs for OAuth applications that allow for direct login.
DIRECT_OAUTH_CLIENTID_WHITELIST = []

View file

@ -168,22 +168,23 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
err = 'unsupported_response_type' err = 'unsupported_response_type'
return self._make_redirect_error_response(redirect_uri, err) return self._make_redirect_error_response(redirect_uri, err)
# Check redirect URI # Check for a valid client ID.
is_valid_client_id = self.validate_client_id(client_id)
if not is_valid_client_id:
err = 'unauthorized_client'
return self._make_redirect_error_response(redirect_uri, err)
# Check for a valid redirect URI.
is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri) is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
if not is_valid_redirect_uri: if not is_valid_redirect_uri:
return self._invalid_redirect_uri_response() return self._invalid_redirect_uri_response()
# Check conditions # Check conditions
is_valid_client_id = self.validate_client_id(client_id)
is_valid_access = self.validate_access() is_valid_access = self.validate_access()
scope = params.get('scope', '') scope = params.get('scope', '')
are_valid_scopes = self.validate_scope(client_id, scope) are_valid_scopes = self.validate_scope(client_id, scope)
# Return proper error responses on invalid conditions # Return proper error responses on invalid conditions
if not is_valid_client_id:
err = 'unauthorized_client'
return self._make_redirect_error_response(redirect_uri, err)
if not is_valid_access: if not is_valid_access:
err = 'access_denied' err = 'access_denied'
return self._make_redirect_error_response(redirect_uri, err) return self._make_redirect_error_response(redirect_uri, err)

View file

@ -12,10 +12,11 @@ import features
from app import app, billing as stripe, build_logs, avatar, signer, log_archive, config_provider from app import app, billing as stripe, build_logs, avatar, signer, log_archive, config_provider
from auth import scopes from auth import scopes
from auth.auth import require_session_login, process_oauth from auth.auth import require_session_login, process_oauth, has_basic_auth, process_auth_or_cookie
from auth.permissions import (AdministerOrganizationPermission, ReadRepositoryPermission, from auth.permissions import (AdministerOrganizationPermission, ReadRepositoryPermission,
SuperUserPermission, AdministerRepositoryPermission, SuperUserPermission, AdministerRepositoryPermission,
ModifyRepositoryPermission) ModifyRepositoryPermission)
from auth.auth_context import get_authenticated_user
from buildtrigger.basehandler import BuildTriggerHandler from buildtrigger.basehandler import BuildTriggerHandler
from buildtrigger.bitbuckethandler import BitbucketBuildTrigger from buildtrigger.bitbuckethandler import BitbucketBuildTrigger
from buildtrigger.customhandler import CustomBuildTrigger from buildtrigger.customhandler import CustomBuildTrigger
@ -452,21 +453,27 @@ def build_status_badge(namespace_name, repo_name):
class FlaskAuthorizationProvider(model.oauth.DatabaseAuthorizationProvider): class FlaskAuthorizationProvider(model.oauth.DatabaseAuthorizationProvider):
def get_authorized_user(self): def get_authorized_user(self):
return current_user.db_user() return get_authenticated_user()
def _make_response(self, body='', headers=None, status_code=200): def _make_response(self, body='', headers=None, status_code=200):
return make_response(body, status_code, headers) return make_response(body, status_code, headers)
@web.route('/oauth/authorizeapp', methods=['POST']) @web.route('/oauth/authorizeapp', methods=['POST'])
@csrf_protect @process_auth_or_cookie
def authorize_application(): def authorize_application():
if not current_user.is_authenticated: # Check for an authenticated user.
if not get_authenticated_user():
abort(401) abort(401)
return return
provider = FlaskAuthorizationProvider() # If direct OAuth is not enabled or the user is not directly authed, verify CSRF.
client_id = request.form.get('client_id', None) client_id = request.form.get('client_id', None)
whitelist = app.config.get('DIRECT_OAUTH_CLIENTID_WHITELIST', [])
if client_id not in whitelist or not has_basic_auth(get_authenticated_user().username):
verify_csrf()
provider = FlaskAuthorizationProvider()
redirect_uri = request.form.get('redirect_uri', None) redirect_uri = request.form.get('redirect_uri', None)
scope = request.form.get('scope', None) scope = request.form.get('scope', None)
@ -474,7 +481,6 @@ def authorize_application():
return provider.get_token_response('token', client_id, redirect_uri, scope=scope) return provider.get_token_response('token', client_id, redirect_uri, scope=scope)
@web.route(app.config['LOCAL_OAUTH_HANDLER'], methods=['GET']) @web.route(app.config['LOCAL_OAUTH_HANDLER'], methods=['GET'])
def oauth_local_handler(): def oauth_local_handler():
if not current_user.is_authenticated: if not current_user.is_authenticated:

View file

@ -3,6 +3,7 @@
import json as py_json import json as py_json
import time import time
import unittest import unittest
import base64
from urllib import urlencode from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs from urlparse import urlparse, urlunparse, parse_qs
@ -86,6 +87,17 @@ class EndpointTestCase(unittest.TestCase):
self.assertEquals(rv.status_code, expected_code) self.assertEquals(rv.status_code, expected_code)
return rv.data return rv.data
def postResponse(self, resource_name, headers=None, form=None, with_csrf=True, expected_code=200, **kwargs):
headers = headers or {}
form = form or {}
url = url_for(resource_name, **kwargs)
if with_csrf:
url = EndpointTestCase._add_csrf(url)
rv = self.app.post(url, headers=headers, data=form)
self.assertEquals(rv.status_code, expected_code)
return rv
def login(self, username, password): def login(self, username, password):
rv = self.app.post(EndpointTestCase._add_csrf(api.url_for(Signin)), rv = self.app.post(EndpointTestCase._add_csrf(api.url_for(Signin)),
data=py_json.dumps(dict(username=username, password=password)), data=py_json.dumps(dict(username=username, password=password)),
@ -191,6 +203,136 @@ class WebEndpointTestCase(EndpointTestCase):
self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302) self.getResponse('web.redirect_to_namespace', namespace='buynlarge', expected_code=302)
class OAuthTestCase(EndpointTestCase):
def test_authorize_nologin(self):
form = {
'client_id': 'someclient',
'redirect_uri': 'http://localhost:5000/foobar',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=401)
def test_authorize_invalidclient(self):
self.login('devtable', 'password')
form = {
'client_id': 'someclient',
'redirect_uri': 'http://localhost:5000/foobar',
'scope': 'user:admin',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertEquals('http://localhost:5000/foobar?error=unauthorized_client', resp.headers['Location'])
def test_authorize_invalidscope(self):
self.login('devtable', 'password')
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'invalid:scope',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertEquals('http://localhost:8000/o2c.html?error=invalid_scope', resp.headers['Location'])
def test_authorize_invalidredirecturi(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://some/invalid/uri',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=400)
def test_authorize_success(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
resp = self.postResponse('web.authorize_application', form=form, with_csrf=True, expected_code=302)
self.assertTrue('access_token=' in resp.headers['Location'])
def test_authorize_nocsrf(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
self.postResponse('web.authorize_application', form=form, with_csrf=False, expected_code=403)
def test_authorize_nocsrf_withinvalidheader(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
headers = dict(authorization='Some random header')
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
def test_authorize_nocsrf_withbadheader(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
def test_authorize_nocsrf_correctheader(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
# Try without the client id being in the whitelist.
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=403)
# Add the client ID to the whitelist and try again.
app.config['DIRECT_OAUTH_CLIENTID_WHITELIST'] = ['deadbeef']
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
resp = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=True, expected_code=302)
self.assertTrue('access_token=' in resp.headers['Location'])
def test_authorize_nocsrf_ratelimiting(self):
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
'redirect_uri': 'http://localhost:8000/o2c.html',
'scope': 'user:admin',
}
# Try without the client id being in the whitelist a few times, making sure we eventually get rate limited.
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=429)
class KeyServerTestCase(EndpointTestCase): class KeyServerTestCase(EndpointTestCase):
def _get_test_jwt_payload(self): def _get_test_jwt_payload(self):
return { return {