Add end-to-end OAuth login and attach tests

This commit is contained in:
Joseph Schorr 2016-12-08 18:35:42 -05:00
parent 36324708db
commit dbdcb802b1
4 changed files with 194 additions and 13 deletions

View file

@ -32,7 +32,10 @@ def render_ologin_error(service_name, error_message=None, register_redirect=Fals
'user_creation': user_creation, 'user_creation': user_creation,
'register_redirect': register_redirect, 'register_redirect': register_redirect,
} }
return index('', error_info=error_info)
resp = index('', error_info=error_info)
resp.status_code = 400
return resp
def get_user(service, token): def get_user(service, token):
@ -114,6 +117,9 @@ def google_oauth_callback():
code = request.args.get('code') code = request.args.get('code')
token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True) token = google_login.exchange_code_for_token(app.config, client, code, form_encode=True)
if token is None:
return render_ologin_error('Google')
user_data = get_user(google_login, token) user_data = get_user(google_login, token)
if not user_data or not user_data.get('id', None) or not user_data.get('email', None): if not user_data or not user_data.get('id', None) or not user_data.get('email', None):
return render_ologin_error('Google') return render_ologin_error('Google')
@ -145,6 +151,8 @@ def github_oauth_callback():
# Exchange the OAuth code. # Exchange the OAuth code.
code = request.args.get('code') code = request.args.get('code')
token = github_login.exchange_code_for_token(app.config, client, code) token = github_login.exchange_code_for_token(app.config, client, code)
if token is None:
return render_ologin_error('GitHub')
# Retrieve the user's information. # Retrieve the user's information.
user_data = get_user(github_login, token) user_data = get_user(github_login, token)
@ -177,6 +185,8 @@ def github_oauth_callback():
# Find the e-mail address for the user: we will accept any email, but we prefer the primary # Find the e-mail address for the user: we will accept any email, but we prefer the primary
get_email = client.get(github_login.email_endpoint(), params=token_param, get_email = client.get(github_login.email_endpoint(), params=token_param,
headers=v3_media_type) headers=v3_media_type)
if get_email.status_code / 100 != 2:
return render_ologin_error('GitHub')
found_email = None found_email = None
for user_email in get_email.json(): for user_email in get_email.json():
@ -206,6 +216,8 @@ def google_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = google_login.exchange_code_for_token(app.config, client, code, token = google_login.exchange_code_for_token(app.config, client, code,
redirect_suffix='/attach', form_encode=True) redirect_suffix='/attach', form_encode=True)
if token is None:
return render_ologin_error('Google')
user_data = get_user(google_login, token) user_data = get_user(google_login, token)
if not user_data or not user_data.get('id', None): if not user_data or not user_data.get('id', None):
@ -243,6 +255,9 @@ def google_oauth_attach():
def github_oauth_attach(): def github_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = github_login.exchange_code_for_token(app.config, client, code) token = github_login.exchange_code_for_token(app.config, client, code)
if token is None:
return render_ologin_error('GitHub')
user_data = get_user(github_login, token) user_data = get_user(github_login, token)
if not user_data: if not user_data:
return render_ologin_error('GitHub') return render_ologin_error('GitHub')
@ -292,10 +307,12 @@ def dex_oauth_callback():
token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True, token = dex_login.exchange_code_for_token(app.config, client, code, client_auth=True,
form_encode=True) form_encode=True)
if token is None:
return render_ologin_error(dex_login.public_title)
try: try:
payload = decode_user_jwt(token, dex_login) payload = decode_user_jwt(token, dex_login)
except InvalidTokenError: except InvalidTokenError as ite:
logger.exception('Exception when decoding returned JWT') logger.exception('Exception when decoding returned JWT')
return render_ologin_error( return render_ologin_error(
dex_login.public_title, dex_login.public_title,
@ -328,7 +345,7 @@ def dex_oauth_attach():
code = request.args.get('code') code = request.args.get('code')
token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach', token = dex_login.exchange_code_for_token(app.config, client, code, redirect_suffix='/attach',
client_auth=True, form_encode=True) client_auth=True, form_encode=True)
if not token: if token is None:
return render_ologin_error(dex_login.public_title) return render_ologin_error(dex_login.public_title)
try: try:

View file

@ -8,6 +8,7 @@ import base64
from urllib import urlencode from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs from urlparse import urlparse, urlunparse, parse_qs
from datetime import datetime, timedelta from datetime import datetime, timedelta
from httmock import urlmatch, HTTMock
import jwt import jwt
@ -22,13 +23,25 @@ from endpoints import keyserver
from endpoints.api import api, api_bp from endpoints.api import api, api_bp
from endpoints.api.user import Signin from endpoints.api.user import Signin
from endpoints.keyserver import jwk_with_kid from endpoints.keyserver import jwk_with_kid
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
from endpoints.web import web as web_bp from endpoints.web import web as web_bp
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
from initdb import setup_database_for_testing, finished_database_for_testing from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import assert_action_logged from test.helpers import assert_action_logged
try:
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth')
except ValueError:
# This blueprint was already registered
pass
try: try:
app.register_blueprint(web_bp, url_prefix='') app.register_blueprint(web_bp, url_prefix='')
except ValueError:
# This blueprint was already registered
pass
try:
app.register_blueprint(keyserver.key_server, url_prefix='') app.register_blueprint(keyserver.key_server, url_prefix='')
except ValueError: except ValueError:
# This blueprint was already registered # This blueprint was already registered
@ -69,6 +82,7 @@ class EndpointTestCase(unittest.TestCase):
def setCsrfToken(self, token): def setCsrfToken(self, token):
with self.app.session_transaction() as sess: with self.app.session_transaction() as sess:
sess[CSRF_TOKEN_KEY] = token sess[CSRF_TOKEN_KEY] = token
sess[OAUTH_CSRF_TOKEN_NAME] = 'someoauthtoken'
def getResponse(self, resource_name, expected_code=200, **kwargs): def getResponse(self, resource_name, expected_code=200, **kwargs):
rv = self.app.get(url_for(resource_name, **kwargs)) rv = self.app.get(url_for(resource_name, **kwargs))
@ -108,6 +122,140 @@ class EndpointTestCase(unittest.TestCase):
self.assertEquals(rv.status_code, 200) self.assertEquals(rv.status_code, 200)
class OAuthLoginTestCase(EndpointTestCase):
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
new_username):
# Test callback.
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
# Delete the created user.
model.user.delete_user(created, [])
# Test attach.
self.login('devtable', 'password')
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
# No CSRF.
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
# Invalid CSRF.
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
# Valid CSRF, invalid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='invalidcode', expected_code=400)
# Valid CSRF, valid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='somecode', expected_code=302)
# Ensure the user was added/modified.
found_user = model.user.get_user(username)
self.assertIsNotNone(found_user)
federated_login = model.user.lookup_federated_login(found_user, service_name)
self.assertIsNotNone(federated_login)
self.assertEquals(federated_login.service_ident, service_ident)
return found_user
def test_google_oauth(self):
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
def account_handler(_, request):
if request.body.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
def user_handler(_, __):
content = {
'id': 'someid',
'email': 'someemail@example.com',
'verified_email': True,
}
return py_json.dumps(content)
with HTTMock(account_handler, user_handler):
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
'someid', 'someemail')
def test_github_oauth(self):
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
def account_handler(url, _):
if url.query.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'github.com', path='/api/v3/user')
def user_handler(_, __):
content = {
'id': 'someid',
'login': 'someusername'
}
return py_json.dumps(content)
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
def email_handler(_, __):
content = [{
'email': 'someemail@example.com',
'verified': True,
'primary': True,
}]
return py_json.dumps(content)
with HTTMock(account_handler, email_handler, user_handler):
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
'someid', 'someusername')
def test_dex_oauth(self):
# TODO(jschorr): Add tests for invalid and expired keys.
# Generate a public/private key pair for the OIDC transaction.
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
token = jwt.encode({
'iss': 'https://oidcserver/',
'aud': 'someclientid',
'sub': 'someid',
'exp': int(time.time()) + 60,
'iat': int(time.time()),
'nbf': int(time.time()),
'email': 'someemail@example.com',
'email_verified': True,
}, private_key.exportKey('PEM'), 'RS256')
@urlmatch(netloc=r'oidcserver', path='/.well-known/openid-configuration')
def wellknown_handler(url, _):
return py_json.dumps({
'authorization_endpoint': 'http://oidcserver/auth',
'token_endpoint': 'http://oidcserver/token',
'jwks_uri': 'http://oidcserver/keys',
})
@urlmatch(netloc=r'oidcserver', path='/token')
def account_handler(url, request):
if request.body.find("code=somecode") > 0:
return py_json.dumps({
'access_token': token,
})
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'oidcserver', path='/keys')
def keys_handler(_, __):
return py_json.dumps({
"keys": [jwk],
})
with HTTMock(wellknown_handler, account_handler, keys_handler):
self.invoke_oauth_tests('dex_oauth_callback', 'dex_oauth_attach', 'dex',
'someid', 'someemail')
class WebEndpointTestCase(EndpointTestCase): class WebEndpointTestCase(EndpointTestCase):
def test_index(self): def test_index(self):
self.getResponse('web.index') self.getResponse('web.index')

View file

@ -75,3 +75,12 @@ class TestConfig(DefaultConfig):
INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem' INSTANCE_SERVICE_KEY_LOCATION = 'test/data/test.pem'
PROMETHEUS_AGGREGATOR_URL = None PROMETHEUS_AGGREGATOR_URL = None
FEATURE_GITHUB_LOGIN = True
FEATURE_GOOGLE_LOGIN = True
FEATURE_DEX_LOGIN = True
DEX_LOGIN_CONFIG = {
'CLIENT_ID': 'someclientid',
'OIDC_SERVER': 'https://oidcserver/',
}

View file

@ -4,6 +4,7 @@ import logging
import time import time
from cachetools import TTLCache from cachetools import TTLCache
from cachetools.func import lru_cache
from jwkest.jwk import KEYS from jwkest.jwk import KEYS
from util import slash_join from util import slash_join
@ -64,12 +65,14 @@ class OAuthConfig(object):
else: else:
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth) get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
if get_access_token.status_code / 100 != 2:
return None
json_data = get_access_token.json() json_data = get_access_token.json()
if not json_data: if not json_data:
return '' return None
token = json_data.get('access_token', '') return json_data.get('access_token', None)
return token
class GithubOAuthConfig(OAuthConfig): class GithubOAuthConfig(OAuthConfig):
@ -265,11 +268,15 @@ class OIDCConfig(OAuthConfig):
super(OIDCConfig, self).__init__(config, key_name) super(OIDCConfig, self).__init__(config, key_name)
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
self._oidc_config = {} self._config = config
self._http_client = config['HTTPCLIENT'] self._http_client = config['HTTPCLIENT']
@lru_cache(maxsize=1)
def _oidc_config(self):
if self.config.get('OIDC_SERVER'): if self.config.get('OIDC_SERVER'):
self._load_via_discovery(config.get('DEBUGGING', False)) return self._load_via_discovery(self._config.get('DEBUGGING', False))
else:
return {}
def _load_via_discovery(self, is_debugging): def _load_via_discovery(self, is_debugging):
oidc_server = self.config['OIDC_SERVER'] oidc_server = self.config['OIDC_SERVER']
@ -283,16 +290,16 @@ class OIDCConfig(OAuthConfig):
raise Exception("Could not load OIDC discovery information") raise Exception("Could not load OIDC discovery information")
try: try:
self._oidc_config = json.loads(discovery.text) return json.loads(discovery.text)
except ValueError: except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise Exception("Could not parse OIDC discovery information") raise Exception("Could not parse OIDC discovery information")
def authorize_endpoint(self): def authorize_endpoint(self):
return self._oidc_config.get('authorization_endpoint', '') + '?' return self._oidc_config().get('authorization_endpoint', '') + '?'
def token_endpoint(self): def token_endpoint(self):
return self._oidc_config.get('token_endpoint') return self._oidc_config().get('token_endpoint')
def user_endpoint(self): def user_endpoint(self):
return None return None
@ -322,9 +329,9 @@ class OIDCConfig(OAuthConfig):
# a random key chose to be stored in the cache, and could be anything. # a random key chose to be stored in the cache, and could be anything.
return self._public_key_cache[None] return self._public_key_cache[None]
def _get_public_key(self): def _get_public_key(self, _):
""" Retrieves the public key for this handler. """ """ Retrieves the public key for this handler. """
keys_url = self._oidc_config['jwks_uri'] keys_url = self._oidc_config()['jwks_uri']
keys = KEYS() keys = KEYS()
keys.load_from_url(keys_url) keys.load_from_url(keys_url)