Refactor our auth handling code to be cleaner

Breaks out the validation code from the auth context modification calls, makes decorators easier to define and adds testing for each individual piece. Will be the basis of better error messaging in the following change.
This commit is contained in:
Joseph Schorr 2017-03-16 17:05:26 -04:00
parent 1bd4422da9
commit 651666b60b
18 changed files with 830 additions and 455 deletions

View file

@ -1,185 +0,0 @@
import base64
import unittest
from datetime import datetime, timedelta
from flask import g
from flask_principal import identity_loaded
from app import app
from auth.scopes import (scopes_from_scope_string, is_subset_string, DIRECT_LOGIN, ADMIN_REPO,
ALL_SCOPES)
from auth.permissions import QuayDeferredPermissionUser
from auth.process import _process_basic_auth
from data import model
from data.database import OAuthApplication, OAuthAccessToken
from endpoints.api import api
from endpoints.api.user import User, Signin
from test.test_api_usage import ApiTestCase
ADMIN_ACCESS_USER = 'devtable'
DISABLED_USER = 'disabled'
@identity_loaded.connect_via(app)
def on_identity_loaded(sender, identity):
g.identity = identity
class TestAuth(ApiTestCase):
def verify_cookie_auth(self, username):
resp = self.getJsonResponse(User)
self.assertEquals(resp['username'], username)
def verify_identity(self, id):
try:
identity = g.identity
except:
identity = None
self.assertIsNotNone(identity)
self.assertEquals(identity.id, id)
def verify_no_identity(self):
try:
identity = g.identity
except:
identity = None
self.assertIsNone(identity)
def conduct_basic_auth(self, username, password):
encoded = base64.b64encode(username + ':' + password)
try:
_process_basic_auth('Basic ' + encoded)
except:
pass
def create_oauth(self, user):
oauth_app = OAuthApplication.create(client_id='onetwothree', redirect_uri='',
application_uri='', organization=user,
name='someapp')
expires_at = datetime.utcnow() + timedelta(seconds=50000)
OAuthAccessToken.create(application=oauth_app, authorized_user=user,
scope='repo:admin',
access_token='access1234', token_type='Bearer',
expires_at=expires_at, refresh_token=None, data={})
def test_login(self):
password = 'password'
resp = self.postJsonResponse(Signin, data=dict(username=ADMIN_ACCESS_USER, password=password))
self.assertTrue(resp.get('success'))
self.verify_cookie_auth(ADMIN_ACCESS_USER)
def test_login_disabled(self):
password = 'password'
self.postJsonResponse(Signin, data=dict(username=DISABLED_USER, password=password),
expected_code=403)
def test_basic_auth_user(self):
user = model.user.get_user(ADMIN_ACCESS_USER)
self.conduct_basic_auth(ADMIN_ACCESS_USER, 'password')
self.verify_identity(user.uuid)
def test_basic_auth_disabled_user(self):
user = model.user.get_user(DISABLED_USER)
self.conduct_basic_auth(DISABLED_USER, 'password')
self.verify_no_identity()
def test_basic_auth_token(self):
token = model.token.create_delegate_token(ADMIN_ACCESS_USER, 'simple', 'sometoken')
self.conduct_basic_auth('$token', token.code)
self.verify_identity(token.code)
def test_basic_auth_invalid_token(self):
self.conduct_basic_auth('$token', 'foobar')
self.verify_no_identity()
def test_basic_auth_invalid_user(self):
self.conduct_basic_auth('foobarinvalid', 'foobar')
self.verify_no_identity()
def test_oauth_invalid(self):
self.conduct_basic_auth('$oauthtoken', 'foobar')
self.verify_no_identity()
def test_oauth_invalid_http_response(self):
rv = self.app.get(api.url_for(User), headers={'Authorization': 'Bearer bad_token'})
assert 'WWW-Authenticate' in rv.headers
self.assertEquals(401, rv.status_code)
def test_oauth_valid_user(self):
user = model.user.get_user(ADMIN_ACCESS_USER)
self.create_oauth(user)
self.conduct_basic_auth('$oauthtoken', 'access1234')
self.verify_identity(user.uuid)
def test_oauth_disabled_user(self):
user = model.user.get_user(DISABLED_USER)
self.create_oauth(user)
self.conduct_basic_auth('$oauthtoken', 'access1234')
self.verify_no_identity()
def test_basic_auth_robot(self):
user = model.user.get_user(ADMIN_ACCESS_USER)
robot, passcode = model.user.get_robot('dtrobot', user)
self.conduct_basic_auth(robot.username, passcode)
self.verify_identity(robot.uuid)
def test_basic_auth_robot_invalidcode(self):
user = model.user.get_user(ADMIN_ACCESS_USER)
robot, _ = model.user.get_robot('dtrobot', user)
self.conduct_basic_auth(robot.username, 'someinvalidcode')
self.verify_no_identity()
def test_deferred_permissions_scopes(self):
self.assertEquals(QuayDeferredPermissionUser.for_id('123454')._scope_set, {DIRECT_LOGIN})
self.assertEquals(QuayDeferredPermissionUser.for_id('123454', {})._scope_set, {})
self.assertEquals(QuayDeferredPermissionUser.for_id('123454', {ADMIN_REPO})._scope_set, {ADMIN_REPO})
def assertParsedScopes(self, scopes_str, *args):
expected_scope_set = {ALL_SCOPES[scope_name] for scope_name in args}
parsed_scope_set = scopes_from_scope_string(scopes_str)
self.assertEquals(parsed_scope_set, expected_scope_set)
def test_scopes_parsing(self):
# Valid single scopes.
self.assertParsedScopes('repo:read', 'repo:read')
self.assertParsedScopes('repo:admin', 'repo:admin')
# Invalid scopes.
self.assertParsedScopes('not:valid')
self.assertParsedScopes('repo:admins')
# Valid scope strings.
self.assertParsedScopes('repo:read repo:admin', 'repo:read', 'repo:admin')
self.assertParsedScopes('repo:read,repo:admin', 'repo:read', 'repo:admin')
self.assertParsedScopes('repo:read,repo:admin repo:write', 'repo:read', 'repo:admin',
'repo:write')
# Partially invalid scopes.
self.assertParsedScopes('repo:read,not:valid')
self.assertParsedScopes('repo:read repo:admins')
# Invalid scope strings.
self.assertParsedScopes('repo:read|repo:admin')
# Mixture of delimiters.
self.assertParsedScopes('repo:read, repo:admin')
def test_subset_string(self):
self.assertTrue(is_subset_string('repo:read', 'repo:read'))
self.assertTrue(is_subset_string('repo:read repo:admin', 'repo:read'))
self.assertTrue(is_subset_string('repo:read,repo:admin', 'repo:read'))
self.assertTrue(is_subset_string('repo:read,repo:admin', 'repo:admin'))
self.assertTrue(is_subset_string('repo:read,repo:admin', 'repo:admin repo:read'))
self.assertFalse(is_subset_string('', 'repo:read'))
self.assertFalse(is_subset_string('unknown:tag', 'repo:read'))
self.assertFalse(is_subset_string('repo:read unknown:tag', 'repo:read'))
self.assertFalse(is_subset_string('repo:read,unknown:tag', 'repo:read'))
if __name__ == '__main__':
unittest.main()

View file

@ -345,8 +345,6 @@ class OAuthTestCase(EndpointTestCase):
self.postResponse('web.authorize_application', form=form, with_csrf=False, expected_code=403)
def test_authorize_nocsrf_withinvalidheader(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
@ -358,8 +356,6 @@ class OAuthTestCase(EndpointTestCase):
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
def test_authorize_nocsrf_withbadheader(self):
self.login('devtable', 'password')
# Note: Defined in initdb.py
form = {
'client_id': 'deadbeef',
@ -368,7 +364,8 @@ class OAuthTestCase(EndpointTestCase):
}
headers = dict(authorization='Basic ' + base64.b64encode('devtable:invalidpassword'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=401)
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False,
expected_code=401)
def test_authorize_nocsrf_correctheader(self):
# Note: Defined in initdb.py
@ -380,13 +377,15 @@ class OAuthTestCase(EndpointTestCase):
# Try without the client id being in the whitelist.
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False, expected_code=403)
self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=False,
expected_code=403)
# Add the client ID to the whitelist and try again.
app.config['DIRECT_OAUTH_CLIENTID_WHITELIST'] = ['deadbeef']
headers = dict(authorization='Basic ' + base64.b64encode('devtable:password'))
resp = self.postResponse('web.authorize_application', headers=headers, form=form, with_csrf=True, expected_code=302)
resp = self.postResponse('web.authorize_application', headers=headers, form=form,
with_csrf=True, expected_code=302)
self.assertTrue('access_token=' in resp.headers['Location'])
def test_authorize_nocsrf_ratelimiting(self):