Add proper and tested OIDC support on the server

Note that this will still not work on the client side; the followup CL for the client side is right after this one.
This commit is contained in:
Joseph Schorr 2017-01-23 17:53:34 -05:00
parent 19f7acf575
commit fda203e4d7
15 changed files with 756 additions and 180 deletions

View file

@ -368,7 +368,7 @@ def update_user_metadata(user, given_name=None, family_name=None, company=None):
remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY) remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY)
def create_federated_user(username, email, service_name, service_ident, def create_federated_user(username, email, service_id, service_ident,
set_password_notification, metadata={}, set_password_notification, metadata={},
email_required=True, prompts=tuple()): email_required=True, prompts=tuple()):
prompts = set(prompts) prompts = set(prompts)
@ -378,7 +378,11 @@ def create_federated_user(username, email, service_name, service_ident,
new_user.verified = True new_user.verified = True
new_user.save() new_user.save()
service = LoginService.get(LoginService.name == service_name) try:
service = LoginService.get(LoginService.name == service_id)
except LoginService.DoesNotExist:
service = LoginService.create(name=service_id)
FederatedLogin.create(user=new_user, service=service, FederatedLogin.create(user=new_user, service=service,
service_ident=service_ident, service_ident=service_ident,
metadata_json=json.dumps(metadata)) metadata_json=json.dumps(metadata))
@ -389,20 +393,20 @@ def create_federated_user(username, email, service_name, service_ident,
return new_user return new_user
def attach_federated_login(user, service_name, service_ident, metadata={}): def attach_federated_login(user, service_id, service_ident, metadata={}):
service = LoginService.get(LoginService.name == service_name) service = LoginService.get(LoginService.name == service_id)
FederatedLogin.create(user=user, service=service, service_ident=service_ident, FederatedLogin.create(user=user, service=service, service_ident=service_ident,
metadata_json=json.dumps(metadata)) metadata_json=json.dumps(metadata))
return user return user
def verify_federated_login(service_name, service_ident): def verify_federated_login(service_id, service_ident):
try: try:
found = (FederatedLogin found = (FederatedLogin
.select(FederatedLogin, User) .select(FederatedLogin, User)
.join(LoginService) .join(LoginService)
.switch(FederatedLogin).join(User) .switch(FederatedLogin).join(User)
.where(FederatedLogin.service_ident == service_ident, LoginService.name == service_name) .where(FederatedLogin.service_ident == service_ident, LoginService.name == service_id)
.get()) .get())
return found.user return found.user
except FederatedLogin.DoesNotExist: except FederatedLogin.DoesNotExist:

View file

@ -197,7 +197,6 @@ def render_page_template(name, route_data=None, **kwargs):
'title': login_service.service_name(), 'title': login_service.service_name(),
'config': login_service.get_public_config(), 'config': login_service.get_public_config(),
'icon': login_service.get_icon(), 'icon': login_service.get_icon(),
'scopes': login_service.get_login_scopes(),
}) })
return login_config return login_config

View file

@ -1,4 +1,5 @@
import logging import logging
import uuid
from flask import request, redirect, url_for, Blueprint from flask import request, redirect, url_for, Blueprint
from peewee import IntegrityError from peewee import IntegrityError
@ -50,6 +51,7 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met
# Try to create the user # Try to create the user
try: try:
# Generate a valid username.
new_username = None new_username = None
for valid in generate_valid_usernames(username): for valid in generate_valid_usernames(username):
if model.user.get_user_or_org(valid): if model.user.get_user_or_org(valid):
@ -58,6 +60,11 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met
new_username = valid new_username = valid
break break
# Generate a valid email. If the email is None and the MAILING feature is turned
# off, simply place in a fake email address.
if email is None and not features.MAILING:
email = '%s@fake.example.com' % (str(uuid.uuid4()))
prompts = model.user.get_default_user_prompts(features) prompts = model.user.get_default_user_prompts(features)
to_login = model.user.create_federated_user(new_username, email, service_id, to_login = model.user.create_federated_user(new_username, email, service_id,
user_id, set_password_notification=True, user_id, set_password_notification=True,
@ -102,6 +109,7 @@ def _register_service(login_service):
try: try:
lid, lusername, lemail = login_service.exchange_code_for_login(app.config, client, code, '') lid, lusername, lemail = login_service.exchange_code_for_login(app.config, client, code, '')
except OAuthLoginException as ole: except OAuthLoginException as ole:
logger.exception('Got login exception')
return _render_ologin_error(login_service.service_name(), ole.message) return _render_ologin_error(login_service.service_name(), ole.message)
# Conduct login. # Conduct login.

View file

@ -22,7 +22,7 @@ EXTERNAL_JS = [
] ]
EXTERNAL_CSS = [ EXTERNAL_CSS = [
'netdna.bootstrapcdn.com/font-awesome/4.6.0/css/font-awesome.css', 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css',
'netdna.bootstrapcdn.com/bootstrap/3.3.2/css/bootstrap.min.css', 'netdna.bootstrapcdn.com/bootstrap/3.3.2/css/bootstrap.min.css',
'fonts.googleapis.com/css?family=Source+Sans+Pro:300,400,700', 'fonts.googleapis.com/css?family=Source+Sans+Pro:300,400,700',
's3.amazonaws.com/cdn.core-os.net/icons/core-icons.css', 's3.amazonaws.com/cdn.core-os.net/icons/core-icons.css',

View file

@ -51,7 +51,7 @@ class OAuthService(object):
def get_redirect_uri(self, app_config, redirect_suffix=''): def get_redirect_uri(self, app_config, redirect_suffix=''):
return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
app_config['SERVER_HOSTNAME'], app_config['SERVER_HOSTNAME'],
self.service_name().lower(), self.service_id(),
redirect_suffix) redirect_suffix)
def get_user_info(self, http_client, token): def get_user_info(self, http_client, token):
@ -74,7 +74,7 @@ class OAuthService(object):
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False, def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
redirect_suffix='', client_auth=False): redirect_suffix='', client_auth=False):
""" Exchanges an OAuth access code for the associated OAuth token. """ """ Exchanges an OAuth access code for the associated OAuth token. """
json_data = self._exchange_code(app_config, http_client, code, form_encode, redirect_suffix, json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
client_auth) client_auth)
access_token = json_data.get('access_token', None) access_token = json_data.get('access_token', None)
@ -84,8 +84,9 @@ class OAuthService(object):
return access_token return access_token
def _exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='', def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
client_auth=False): client_auth=False):
""" Exchanges an OAuth access code for associated OAuth token and other data. """
payload = { payload = {
'code': code, 'code': code,
'grant_type': 'authorization_code', 'grant_type': 'authorization_code',

View file

@ -14,6 +14,10 @@ class OAuthLoginService(OAuthService):
""" A base class for defining an OAuth-compliant service that can be used for, amongst other """ A base class for defining an OAuth-compliant service that can be used for, amongst other
things, login and authentication. """ things, login and authentication. """
def login_enabled(self):
""" Returns true if the login service is enabled. """
raise NotImplementedError
def get_login_service_id(self, user_info): def get_login_service_id(self, user_info):
""" Returns the internal ID for the given user under this login service. """ """ Returns the internal ID for the given user under this login service. """
raise NotImplementedError raise NotImplementedError

View file

@ -1,7 +1,11 @@
import features
from oauth.services.github import GithubOAuthService from oauth.services.github import GithubOAuthService
from oauth.services.google import GoogleOAuthService from oauth.services.google import GoogleOAuthService
from oauth.oidc import OIDCLoginService
CUSTOM_LOGIN_SERVICES = {
'GITHUB_LOGIN_CONFIG': GithubOAuthService,
'GOOGLE_LOGIN_CONFIG': GoogleOAuthService,
}
class OAuthLoginManager(object): class OAuthLoginManager(object):
""" Helper class which manages all registered OAuth login services. """ """ Helper class which manages all registered OAuth login services. """
@ -9,11 +13,12 @@ class OAuthLoginManager(object):
self.services = [] self.services = []
# Register the endpoints for each of the OAuth login services. # Register the endpoints for each of the OAuth login services.
# TODO(jschorr): make this dynamic. for key in config.keys():
if config.get('GITHUB_LOGIN_CONFIG') is not None and features.GITHUB_LOGIN: # All keys which end in _LOGIN_CONFIG setup a login service.
github_service = GithubOAuthService(config, 'GITHUB_LOGIN_CONFIG') if key.endswith('_LOGIN_CONFIG'):
self.services.append(github_service) if key in CUSTOM_LOGIN_SERVICES:
custom_service = CUSTOM_LOGIN_SERVICES[key](config, key)
if config.get('GOOGLE_LOGIN_CONFIG') is not None and features.GOOGLE_LOGIN: if custom_service.login_enabled(config):
google_service = GoogleOAuthService(config, 'GOOGLE_LOGIN_CONFIG') self.services.append(custom_service)
self.services.append(google_service) else:
self.services.append(OIDCLoginService(config, key))

View file

@ -3,108 +3,244 @@ import json
import logging import logging
import urlparse import urlparse
import jwt
from cachetools import lru_cache from cachetools import lru_cache
from cachetools.ttl import TTLCache from cachetools.ttl import TTLCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS
from util.oauth.base import OAuthService from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError
from util import get_app_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def decode_user_jwt(token, oidc_provider):
try:
return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'],
audience=oidc_provider.client_id(),
issuer=oidc_provider.issuer)
except InvalidTokenError:
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'],
audience=oidc_provider.client_id(),
issuer=oidc_provider.issuer)
OIDC_WELLKNOWN = ".well-known/openid-configuration" OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
ALLOWED_ALGORITHMS = ['RS256']
JWT_CLOCK_SKEW_SECONDS = 30
class OIDCConfig(OAuthService): class DiscoveryFailureException(Exception):
""" Exception raised when OIDC discovery fails. """
pass
class PublicKeyLoadException(Exception):
""" Exception raised if loading the OIDC public key fails. """
pass
class OIDCLoginService(OAuthService):
""" Defines a generic service for all OpenID-connect compatible login services. """
def __init__(self, config, key_name): def __init__(self, config, key_name):
super(OIDCConfig, self).__init__(config, key_name) super(OIDCLoginService, self).__init__(config, key_name)
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key)
self._config = config self._id = key_name[0:key_name.find('_')].lower()
self._http_client = config['HTTPCLIENT'] self._http_client = config['HTTPCLIENT']
self._mailing = config.get('FEATURE_MAILING', False)
@lru_cache(maxsize=1) def service_id(self):
def _oidc_config(self): return self._id
if self.config.get('OIDC_SERVER'):
return self._load_via_discovery(self._config.get('DEBUGGING', False))
else:
return {}
def _load_via_discovery(self, is_debugging): def service_name(self):
oidc_server = self.config['OIDC_SERVER'] return self.config.get('SERVICE_NAME', self.service_id())
if not oidc_server.startswith('https://') and not is_debugging:
raise Exception('OIDC server must be accessed over SSL')
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN) def get_icon(self):
discovery = self._http_client.get(discovery_url, timeout=5) return self.config.get('SERVICE_ICON', 'fa-user-circle')
if discovery.status_code / 100 != 2: def get_login_scopes(self):
raise Exception("Could not load OIDC discovery information") default_scopes = ['openid']
try: if self.user_endpoint() is not None:
return json.loads(discovery.text) default_scopes.append('profile')
except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) if self._mailing:
raise Exception("Could not parse OIDC discovery information") default_scopes.append('email')
return self._oidc_config().get('scopes_supported', default_scopes)
def authorize_endpoint(self): def authorize_endpoint(self):
return self._oidc_config().get('authorization_endpoint', '') + '?' return self._oidc_config().get('authorization_endpoint', '') + '?response_type=code&'
def token_endpoint(self): def token_endpoint(self):
return self._oidc_config().get('token_endpoint') return self._oidc_config().get('token_endpoint')
def user_endpoint(self): def user_endpoint(self):
return None return self._oidc_config().get('userinfo_endpoint')
def validate_client_id_and_secret(self, http_client, app_config): def validate_client_id_and_secret(self, http_client, app_config):
pass # TODO: find a way to verify client secret too.
redirect_url = '%s/oauth2/%s/callback' % (get_app_url(app_config), self.service_id())
scopes_string = ' '.join(self.get_login_scopes())
authorize_url = '%sclient_id=%s&redirect_uri=%s&scope=%s' % (self.authorize_endpoint(),
self.client_id(),
redirect_url,
scopes_string)
check_auth_url = http_client.get(authorize_url)
if check_auth_url.status_code // 100 != 2:
raise Exception('Got non-200 status code for authorization endpoint')
def requires_form_encoding(self):
return True
def get_public_config(self): def get_public_config(self):
return { return {
'CLIENT_ID': self.client_id(), 'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
'OIDC': True, 'OIDC': True,
} }
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
# Exchange the code for the access token and id_token
try:
json_data = self.exchange_code(app_config, http_client, code,
redirect_suffix=redirect_suffix,
form_encode=self.requires_form_encoding())
except OAuthExchangeCodeException as oce:
raise OAuthLoginException(oce.message)
# Make sure we received both.
access_token = json_data.get('access_token', None)
if access_token is None:
logger.debug('Missing access_token in response: %s', json_data)
raise OAuthLoginException('Missing `access_token` in OIDC response')
id_token = json_data.get('id_token', None)
if id_token is None:
logger.debug('Missing id_token in response: %s', json_data)
raise OAuthLoginException('Missing `id_token` in OIDC response')
# Decode the id_token.
try:
decoded_id_token = self._decode_user_jwt(id_token)
except InvalidTokenError as ite:
logger.exception('Got invalid token error on OIDC decode: %s', ite.message)
raise OAuthLoginException('Could not decode OIDC token')
except PublicKeyLoadException as pke:
logger.exception('Could not load public key during OIDC decode: %s', pke.message)
raise OAuthLoginException('Could find public OIDC key')
# Retrieve the user information.
try:
user_info = self.get_user_info(http_client, access_token)
except OAuthGetUserInfoException as oge:
raise OAuthLoginException(oge.message)
# Verify subs.
if user_info['sub'] != decoded_id_token['sub']:
raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')
# Check if we have a verified email address.
email_address = user_info.get('email') if user_info.get('email_verified') else None
if self._mailing:
if email_address is None:
raise OAuthLoginException('A verified email address is required to login with this service')
# Check for a preferred username.
lusername = user_info.get('preferred_username') or user_info.get('sub')
return decoded_id_token['sub'], lusername, email_address
@property @property
def issuer(self): def _issuer(self):
return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER']) return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER'])
def get_public_key(self, force_refresh=False): @lru_cache(maxsize=1)
""" Retrieves the public key for this handler. """ def _oidc_config(self):
if self.config.get('OIDC_SERVER'):
return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
else:
return {}
def _load_oidc_config_via_discovery(self, is_debugging):
""" Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True,
non-secure connections are alllowed. Raises an DiscoveryFailureException on failure.
"""
oidc_server = self.config['OIDC_SERVER']
if not oidc_server.startswith('https://') and not is_debugging:
raise DiscoveryFailureException('OIDC server must be accessed over SSL')
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
if discovery.status_code // 100 != 2:
logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
raise DiscoveryFailureException("Could not load OIDC discovery information")
try:
return json.loads(discovery.text)
except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise DiscoveryFailureException("Could not parse OIDC discovery information")
def _decode_user_jwt(self, token):
""" Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
exception on an invalid token or a PublicKeyLoadException if the public key could not be
loaded for decoding.
"""
# Find the key to use.
headers = jwt.get_unverified_header(token)
kid = headers.get('kid', None)
if kid is None:
raise InvalidTokenError('Missing `kid` header')
try:
return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
except InvalidTokenError:
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
return decode(token, self._get_public_key(kid, force_refresh=True),
algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
def _get_public_key(self, kid, force_refresh=False):
""" Retrieves the public key for this handler with the given kid. Raises a
PublicKeyLoadException on failure. """
# If force_refresh is true, we expire all the items in the cache by setting the time to # If force_refresh is true, we expire all the items in the cache by setting the time to
# the current time + the expiration TTL. # the current time + the expiration TTL.
if force_refresh: if force_refresh:
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL) self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
# Retrieve the public key from the cache. If the cache does not contain the public key, # Retrieve the public key from the cache. If the cache does not contain the public key,
# it will internally call _get_public_key to retrieve it and then save it. The None is # it will internally call _load_public_key to retrieve it and then save it.
# a random key chose to be stored in the cache, and could be anything. return self._public_key_cache[kid]
return self._public_key_cache[None]
def _get_public_key(self, _): def _load_public_key(self, kid):
""" Retrieves the public key for this handler. """ """ Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
on failure.
"""
keys_url = self._oidc_config()['jwks_uri'] keys_url = self._oidc_config()['jwks_uri']
# Load the keys.
try:
keys = KEYS() keys = KEYS()
keys.load_from_url(keys_url) keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
except Exception as ex:
logger.exception('Exception loading public key')
raise PublicKeyLoadException(ex.message)
if not list(keys): # Find the matching key.
raise Exception('No keys provided by OIDC provider') keys_found = keys.by_kid(kid)
if len(keys_found) == 0:
raise PublicKeyLoadException('Public key %s not found' % kid)
rsa_key = list(keys)[0] rsa_keys = [key for key in keys_found if key.kty == 'RSA']
rsa_key.deserialize() if len(rsa_keys) == 0:
raise PublicKeyLoadException('No RSA form of public key %s not found' % kid)
matching_key = rsa_keys[0]
matching_key.deserialize()
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
# issues. # issues.
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend()) return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())

View file

@ -9,6 +9,9 @@ class GithubOAuthService(OAuthLoginService):
def __init__(self, config, key_name): def __init__(self, config, key_name):
super(GithubOAuthService, self).__init__(config, key_name) super(GithubOAuthService, self).__init__(config, key_name)
def login_enabled(self, config):
return config.get('FEATURE_GITHUB_LOGIN', False)
def service_id(self): def service_id(self):
return 'github' return 'github'

View file

@ -12,6 +12,9 @@ class GoogleOAuthService(OAuthLoginService):
def __init__(self, config, key_name): def __init__(self, config, key_name):
super(GoogleOAuthService, self).__init__(config, key_name) super(GoogleOAuthService, self).__init__(config, key_name)
def login_enabled(self, config):
return config.get('FEATURE_GOOGLE_LOGIN', False)
def service_id(self): def service_id(self):
return 'google' return 'google'

View file

@ -0,0 +1,62 @@
from oauth.loginmanager import OAuthLoginManager
from oauth.services.github import GithubOAuthService
from oauth.services.google import GoogleOAuthService
from oauth.oidc import OIDCLoginService
def test_login_manager_github():
config = {
'FEATURE_GITHUB_LOGIN': True,
'GITHUB_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], GithubOAuthService)
def test_github_disabled():
config = {
'GITHUB_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 0
def test_login_manager_google():
config = {
'FEATURE_GOOGLE_LOGIN': True,
'GOOGLE_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], GoogleOAuthService)
def test_google_disabled():
config = {
'GOOGLE_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 0
def test_oidc():
config = {
'SOMECOOL_LOGIN_CONFIG': {},
'HTTPCLIENT': None,
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], OIDCLoginService)
def test_multiple_oidc():
config = {
'SOMECOOL_LOGIN_CONFIG': {},
'ANOTHER_LOGIN_CONFIG': {},
'HTTPCLIENT': None,
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 2
assert isinstance(loginmanager.services[0], OIDCLoginService)
assert isinstance(loginmanager.services[1], OIDCLoginService)

273
oauth/test/test_oidc.py Normal file
View file

@ -0,0 +1,273 @@
# pylint: disable=redefined-outer-name, unused-argument, C0103, C0111, too-many-arguments
import json
import time
import urlparse
import jwt
import pytest
import requests
from httmock import urlmatch, HTTMock
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from oauth.oidc import OIDCLoginService, OAuthLoginException
@pytest.fixture()
def http_client():
sess = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100,
pool_maxsize=100)
sess.mount('http://', adapter)
sess.mount('https://', adapter)
return sess
@pytest.fixture(params=[True, False])
def app_config(http_client, request):
return {
'PREFERRED_URL_SCHEME': 'http',
'SERVER_HOSTNAME': 'localhost',
'FEATURE_MAILING': request.param,
'SOMEOIDC_TEST_SERVICE': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'SERVICE_NAME': 'Some Cool Service',
'SERVICE_ICON': 'http://some/icon',
'OIDC_SERVER': 'http://fakeoidc',
'DEBUGGING': True,
},
'HTTPCLIENT': http_client,
}
@pytest.fixture()
def oidc_service(app_config):
return OIDCLoginService(app_config, 'SOMEOIDC_TEST_SERVICE')
@pytest.fixture()
def discovery_content():
return {
'scopes_supported': ['profile'],
'authorization_endpoint': 'http://fakeoidc/authorize',
'token_endpoint': 'http://fakeoidc/token',
'userinfo_endpoint': 'http://fakeoidc/userinfo',
'jwks_uri': 'http://fakeoidc/jwks',
}
@pytest.fixture()
def discovery_handler(discovery_content):
@urlmatch(netloc=r'fakeoidc', path=r'.+openid.+')
def handler(_, __):
return json.dumps(discovery_content)
return handler
@pytest.fixture(scope="module") # Slow to generate, only do it once.
def signing_key():
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
return {
'id': 'somekey',
'private_key': private_key.exportKey('PEM'),
'jwk': jwk,
}
@pytest.fixture()
def id_token(oidc_service, signing_key, app_config):
token_data = {
'iss': oidc_service.config['OIDC_SERVER'],
'aud': oidc_service.client_id(),
'nbf': int(time.time()),
'iat': int(time.time()),
'exp': int(time.time() + 600),
'sub': 'cooluser',
}
token_headers = {
'kid': signing_key['id'],
}
return jwt.encode(token_data, signing_key['private_key'], 'RS256', headers=token_headers)
@pytest.fixture()
def valid_code():
return 'validcode'
@pytest.fixture()
def token_handler(oidc_service, id_token, valid_code):
@urlmatch(netloc=r'fakeoidc', path=r'/token')
def handler(_, request):
params = urlparse.parse_qs(request.body)
if params.get('redirect_uri')[0] != 'http://localhost/oauth2/someoidc/callback':
return {'status_code': 400, 'content': 'Invalid redirect URI'}
if params.get('client_id')[0] != oidc_service.client_id():
return {'status_code': 401, 'content': 'Invalid client id'}
if params.get('client_secret')[0] != oidc_service.client_secret():
return {'status_code': 401, 'content': 'Invalid client secret'}
if params.get('code')[0] != valid_code:
return {'status_code': 401, 'content': 'Invalid code'}
if params.get('grant_type')[0] != 'authorization_code':
return {'status_code': 400, 'content': 'Invalid authorization type'}
content = {
'access_token': 'sometoken',
'id_token': id_token,
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def jwks_handler(signing_key):
def jwk_with_kid(kid, jwk):
jwk = jwk.copy()
jwk.update({'kid': kid})
return jwk
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
def handler(_, __):
content = {'keys': [jwk_with_kid(signing_key['id'], signing_key['jwk'])]}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def emptykeys_jwks_handler():
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
def handler(_, __):
content = {'keys': []}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture(params=["someusername", None])
def preferred_username(request):
return request.param
@pytest.fixture
def userinfo_handler(oidc_service, preferred_username):
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
def handler(_, __):
content = {
'sub': 'cooluser',
'preferred_username':preferred_username,
'email': 'foo@example.com',
'email_verified': True,
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def invalidsub_userinfo_handler(oidc_service):
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
def handler(_, __):
content = {
'sub': 'invalidsub',
'preferred_username': 'someusername',
'email': 'foo@example.com',
'email_verified': True,
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def missingemail_userinfo_handler(oidc_service, preferred_username):
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
def handler(_, __):
content = {
'sub': 'cooluser',
'preferred_username': preferred_username,
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
def test_basic_config(oidc_service):
assert oidc_service.service_id() == 'someoidc'
assert oidc_service.service_name() == 'Some Cool Service'
assert oidc_service.get_icon() == 'http://some/icon'
def test_discovery(oidc_service, http_client, discovery_handler):
with HTTMock(discovery_handler):
assert oidc_service.authorize_endpoint() == 'http://fakeoidc/authorize?response_type=code&'
assert oidc_service.token_endpoint() == 'http://fakeoidc/token'
assert oidc_service.user_endpoint() == 'http://fakeoidc/userinfo'
assert oidc_service.get_login_scopes() == ['profile']
def test_public_config(oidc_service, discovery_handler):
with HTTMock(discovery_handler):
assert oidc_service.get_public_config()['OIDC']
assert oidc_service.get_public_config()['CLIENT_ID'] == 'foo'
assert 'CLIENT_SECRET' not in oidc_service.get_public_config()
assert 'bar' not in oidc_service.get_public_config().values()
def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client,
token_handler):
with HTTMock(token_handler, discovery_handler):
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, 'testcode', '')
def test_exchange_code_validcode(oidc_service, discovery_handler, app_config, http_client,
token_handler, userinfo_handler, jwks_handler, valid_code,
preferred_username):
with HTTMock(jwks_handler, token_handler, userinfo_handler, discovery_handler):
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
valid_code, '')
assert lid == 'cooluser'
assert lemail == 'foo@example.com'
if preferred_username is not None:
assert lusername == preferred_username
else:
assert lusername == lid
def test_exchange_code_missingemail(oidc_service, discovery_handler, app_config, http_client,
token_handler, missingemail_userinfo_handler, jwks_handler,
valid_code, preferred_username):
with HTTMock(jwks_handler, token_handler, missingemail_userinfo_handler, discovery_handler):
if app_config['FEATURE_MAILING']:
# Should fail because there is no valid email address.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
else:
# Should succeed because, while there is no valid email address, it isn't necessary with
# mailing disabled.
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
valid_code, '')
assert lid == 'cooluser'
assert lemail is None
if preferred_username is not None:
assert lusername == preferred_username
else:
assert lusername == lid
def test_exchange_code_invalidsub(oidc_service, discovery_handler, app_config, http_client,
token_handler, invalidsub_userinfo_handler, jwks_handler,
valid_code):
with HTTMock(jwks_handler, token_handler, invalidsub_userinfo_handler, discovery_handler):
# Should fail because the sub of the user info doesn't match that returned by the id_token.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
def test_exchange_code_missingkey(oidc_service, discovery_handler, app_config, http_client,
token_handler, userinfo_handler, emptykeys_jwks_handler,
valid_code):
with HTTMock(emptykeys_jwks_handler, token_handler, userinfo_handler, discovery_handler):
# Should fail because the key is missing.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')

View file

@ -8,7 +8,6 @@ import base64
from urllib import urlencode from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qs from urlparse import urlparse, urlunparse, parse_qs
from datetime import datetime, timedelta from datetime import datetime, timedelta
from httmock import urlmatch, HTTMock
import jwt import jwt
@ -16,7 +15,7 @@ from Crypto.PublicKey import RSA
from flask import url_for from flask import url_for
from jwkest.jwk import RSAKey from jwkest.jwk import RSAKey
from app import app, oauth_login from app import app
from data import model from data import model
from data.database import ServiceKeyApprovalType from data.database import ServiceKeyApprovalType
from endpoints import keyserver from endpoints import keyserver
@ -25,16 +24,9 @@ from endpoints.api.user import Signin
from endpoints.keyserver import jwk_with_kid from endpoints.keyserver import jwk_with_kid
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
from endpoints.web import web as web_bp from endpoints.web import web as web_bp
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
from initdb import setup_database_for_testing, finished_database_for_testing from initdb import setup_database_for_testing, finished_database_for_testing
from test.helpers import assert_action_logged from test.helpers import assert_action_logged
try:
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2')
except ValueError:
# This blueprint was already registered
pass
try: try:
app.register_blueprint(web_bp, url_prefix='') app.register_blueprint(web_bp, url_prefix='')
except ValueError: except ValueError:
@ -129,96 +121,6 @@ class EndpointTestCase(unittest.TestCase):
self.assertEquals(rv.status_code, 200) self.assertEquals(rv.status_code, 200)
class OAuthLoginTestCase(EndpointTestCase):
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
new_username):
# Test callback.
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
# Delete the created user.
model.user.delete_user(created, [])
# Test attach.
self.login('devtable', 'password')
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
# No CSRF.
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
# Invalid CSRF.
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
# Valid CSRF, invalid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='invalidcode', expected_code=400)
# Valid CSRF, valid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='somecode', expected_code=302)
# Ensure the user was added/modified.
found_user = model.user.get_user(username)
self.assertIsNotNone(found_user)
federated_login = model.user.lookup_federated_login(found_user, service_name)
self.assertIsNotNone(federated_login)
self.assertEquals(federated_login.service_ident, service_ident)
return found_user
def test_google_oauth(self):
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
def account_handler(_, request):
if request.body.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
def user_handler(_, __):
content = {
'id': 'someid',
'email': 'someemail@example.com',
'verified_email': True,
}
return py_json.dumps(content)
with HTTMock(account_handler, user_handler):
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
'someid', 'someemail')
def test_github_oauth(self):
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
def account_handler(url, _):
if url.query.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'github.com', path='/api/v3/user')
def user_handler(_, __):
content = {
'id': 'someid',
'login': 'someusername'
}
return py_json.dumps(content)
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
def email_handler(_, __):
content = [{
'email': 'someemail@example.com',
'verified': True,
'primary': True,
}]
return py_json.dumps(content)
with HTTMock(account_handler, email_handler, user_handler):
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
'someid', 'someusername')
class WebEndpointTestCase(EndpointTestCase): class WebEndpointTestCase(EndpointTestCase):
def test_index(self): def test_index(self):
self.getResponse('web.index') self.getResponse('web.index')

175
test/test_oauth_login.py Normal file
View file

@ -0,0 +1,175 @@
import json as py_json
import time
import unittest
import jwt
from Crypto.PublicKey import RSA
from httmock import urlmatch, HTTMock
from jwkest.jwk import RSAKey
from app import app
from data import model
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
from test.test_endpoints import EndpointTestCase
try:
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2')
except ValueError:
# This blueprint was already registered
pass
class OAuthLoginTestCase(EndpointTestCase):
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
new_username):
# Test callback.
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
# Delete the created user.
model.user.delete_user(created, [])
# Test attach.
self.login('devtable', 'password')
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
# No CSRF.
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
# Invalid CSRF.
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
# Valid CSRF, invalid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='invalidcode', expected_code=400)
# Valid CSRF, valid code.
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
code='somecode', expected_code=302)
# Ensure the user was added/modified.
found_user = model.user.get_user(username)
self.assertIsNotNone(found_user)
federated_login = model.user.lookup_federated_login(found_user, service_name)
self.assertIsNotNone(federated_login)
self.assertEquals(federated_login.service_ident, service_ident)
return found_user
def test_google_oauth(self):
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
def account_handler(_, request):
if request.body.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
def user_handler(_, __):
content = {
'id': 'someid',
'email': 'someemail@example.com',
'verified_email': True,
}
return py_json.dumps(content)
with HTTMock(account_handler, user_handler):
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
'someid', 'someemail')
def test_github_oauth(self):
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
def account_handler(url, _):
if url.query.find("code=somecode") > 0:
content = {'access_token': 'someaccesstoken'}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'github.com', path='/api/v3/user')
def user_handler(_, __):
content = {
'id': 'someid',
'login': 'someusername'
}
return py_json.dumps(content)
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
def email_handler(_, __):
content = [{
'email': 'someemail@example.com',
'verified': True,
'primary': True,
}]
return py_json.dumps(content)
with HTTMock(account_handler, email_handler, user_handler):
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
'someid', 'someusername')
def test_oidc_auth(self):
private_key = RSA.generate(2048)
generatedjwk = RSAKey(key=private_key.publickey()).serialize()
kid = 'somekey'
private_pem = private_key.exportKey('PEM')
token_data = {
'iss': app.config['TESTOIDC_LOGIN_CONFIG']['OIDC_SERVER'],
'aud': app.config['TESTOIDC_LOGIN_CONFIG']['CLIENT_ID'],
'nbf': int(time.time()),
'iat': int(time.time()),
'exp': int(time.time() + 600),
'sub': 'cooluser',
}
token_headers = {
'kid': kid,
}
id_token = jwt.encode(token_data, private_pem, 'RS256', headers=token_headers)
@urlmatch(netloc=r'fakeoidc', path='/token')
def token_handler(_, request):
if request.body.find("code=somecode") >= 0:
content = {'access_token': 'someaccesstoken', 'id_token': id_token}
return py_json.dumps(content)
else:
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
@urlmatch(netloc=r'fakeoidc', path='/user')
def user_handler(_, __):
content = {
'sub': 'cooluser',
'preferred_username': 'someusername',
'email': 'someemail@example.com',
'email_verified': True,
}
return py_json.dumps(content)
@urlmatch(netloc=r'fakeoidc', path='/jwks')
def jwks_handler(_, __):
jwk = generatedjwk.copy()
jwk.update({'kid': kid})
content = {'keys': [jwk]}
return py_json.dumps(content)
@urlmatch(netloc=r'fakeoidc', path='.+openid.+')
def discovery_handler(_, __):
content = {
'scopes_supported': ['profile'],
'authorization_endpoint': 'http://fakeoidc/authorize',
'token_endpoint': 'http://fakeoidc/token',
'userinfo_endpoint': 'http://fakeoidc/userinfo',
'jwks_uri': 'http://fakeoidc/jwks',
}
return py_json.dumps(content)
with HTTMock(discovery_handler, jwks_handler, token_handler, user_handler):
self.invoke_oauth_tests('testoidc_oauth_callback', 'testoidc_oauth_attach', 'testoidc',
'cooluser', 'someusername')
if __name__ == '__main__':
unittest.main()

View file

@ -81,11 +81,12 @@ class TestConfig(DefaultConfig):
FEATURE_GITHUB_LOGIN = True FEATURE_GITHUB_LOGIN = True
FEATURE_GOOGLE_LOGIN = True FEATURE_GOOGLE_LOGIN = True
FEATURE_DEX_LOGIN = True
DEX_LOGIN_CONFIG = { TESTOIDC_LOGIN_CONFIG = {
'CLIENT_ID': 'someclientid', 'CLIENT_ID': 'foo',
'OIDC_SERVER': 'https://oidcserver/', 'CLIENT_SECRET': 'bar',
'OIDC_SERVER': 'http://fakeoidc',
'DEBUGGING': True,
} }
RECAPTCHA_SITE_KEY = 'somekey' RECAPTCHA_SITE_KEY = 'somekey'