diff --git a/data/model/user.py b/data/model/user.py index ddb27400d..87e239d73 100644 --- a/data/model/user.py +++ b/data/model/user.py @@ -368,7 +368,7 @@ def update_user_metadata(user, given_name=None, family_name=None, company=None): remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY) -def create_federated_user(username, email, service_name, service_ident, +def create_federated_user(username, email, service_id, service_ident, set_password_notification, metadata={}, email_required=True, prompts=tuple()): prompts = set(prompts) @@ -378,7 +378,11 @@ def create_federated_user(username, email, service_name, service_ident, new_user.verified = True new_user.save() - service = LoginService.get(LoginService.name == service_name) + try: + service = LoginService.get(LoginService.name == service_id) + except LoginService.DoesNotExist: + service = LoginService.create(name=service_id) + FederatedLogin.create(user=new_user, service=service, service_ident=service_ident, metadata_json=json.dumps(metadata)) @@ -389,20 +393,20 @@ def create_federated_user(username, email, service_name, service_ident, return new_user -def attach_federated_login(user, service_name, service_ident, metadata={}): - service = LoginService.get(LoginService.name == service_name) +def attach_federated_login(user, service_id, service_ident, metadata={}): + service = LoginService.get(LoginService.name == service_id) FederatedLogin.create(user=user, service=service, service_ident=service_ident, metadata_json=json.dumps(metadata)) return user -def verify_federated_login(service_name, service_ident): +def verify_federated_login(service_id, service_ident): try: found = (FederatedLogin .select(FederatedLogin, User) .join(LoginService) .switch(FederatedLogin).join(User) - .where(FederatedLogin.service_ident == service_ident, LoginService.name == service_name) + .where(FederatedLogin.service_ident == service_ident, LoginService.name == service_id) .get()) return found.user except FederatedLogin.DoesNotExist: diff --git a/endpoints/common.py b/endpoints/common.py index fd52ee868..30fd11947 100644 --- a/endpoints/common.py +++ b/endpoints/common.py @@ -197,7 +197,6 @@ def render_page_template(name, route_data=None, **kwargs): 'title': login_service.service_name(), 'config': login_service.get_public_config(), 'icon': login_service.get_icon(), - 'scopes': login_service.get_login_scopes(), }) return login_config diff --git a/endpoints/oauthlogin.py b/endpoints/oauthlogin.py index 0c7865213..f1bffc064 100644 --- a/endpoints/oauthlogin.py +++ b/endpoints/oauthlogin.py @@ -1,4 +1,5 @@ import logging +import uuid from flask import request, redirect, url_for, Blueprint from peewee import IntegrityError @@ -50,6 +51,7 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met # Try to create the user try: + # Generate a valid username. new_username = None for valid in generate_valid_usernames(username): if model.user.get_user_or_org(valid): @@ -58,6 +60,11 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met new_username = valid break + # Generate a valid email. If the email is None and the MAILING feature is turned + # off, simply place in a fake email address. + if email is None and not features.MAILING: + email = '%s@fake.example.com' % (str(uuid.uuid4())) + prompts = model.user.get_default_user_prompts(features) to_login = model.user.create_federated_user(new_username, email, service_id, user_id, set_password_notification=True, @@ -102,6 +109,7 @@ def _register_service(login_service): try: lid, lusername, lemail = login_service.exchange_code_for_login(app.config, client, code, '') except OAuthLoginException as ole: + logger.exception('Got login exception') return _render_ologin_error(login_service.service_name(), ole.message) # Conduct login. diff --git a/external_libraries.py b/external_libraries.py index d99431283..f2f9c3832 100644 --- a/external_libraries.py +++ b/external_libraries.py @@ -22,7 +22,7 @@ EXTERNAL_JS = [ ] EXTERNAL_CSS = [ - 'netdna.bootstrapcdn.com/font-awesome/4.6.0/css/font-awesome.css', + 'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css', 'netdna.bootstrapcdn.com/bootstrap/3.3.2/css/bootstrap.min.css', 'fonts.googleapis.com/css?family=Source+Sans+Pro:300,400,700', 's3.amazonaws.com/cdn.core-os.net/icons/core-icons.css', diff --git a/oauth/base.py b/oauth/base.py index f25bbc6ab..93b1c8a3c 100644 --- a/oauth/base.py +++ b/oauth/base.py @@ -51,7 +51,7 @@ class OAuthService(object): def get_redirect_uri(self, app_config, redirect_suffix=''): return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], app_config['SERVER_HOSTNAME'], - self.service_name().lower(), + self.service_id(), redirect_suffix) def get_user_info(self, http_client, token): @@ -74,8 +74,8 @@ class OAuthService(object): def exchange_code_for_token(self, app_config, http_client, code, form_encode=False, redirect_suffix='', client_auth=False): """ Exchanges an OAuth access code for the associated OAuth token. """ - json_data = self._exchange_code(app_config, http_client, code, form_encode, redirect_suffix, - client_auth) + json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix, + client_auth) access_token = json_data.get('access_token', None) if access_token is None: @@ -84,8 +84,9 @@ class OAuthService(object): return access_token - def _exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='', - client_auth=False): + def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='', + client_auth=False): + """ Exchanges an OAuth access code for associated OAuth token and other data. """ payload = { 'code': code, 'grant_type': 'authorization_code', diff --git a/oauth/login.py b/oauth/login.py index a4aa5524e..268d030f7 100644 --- a/oauth/login.py +++ b/oauth/login.py @@ -14,6 +14,10 @@ class OAuthLoginService(OAuthService): """ A base class for defining an OAuth-compliant service that can be used for, amongst other things, login and authentication. """ + def login_enabled(self): + """ Returns true if the login service is enabled. """ + raise NotImplementedError + def get_login_service_id(self, user_info): """ Returns the internal ID for the given user under this login service. """ raise NotImplementedError diff --git a/oauth/loginmanager.py b/oauth/loginmanager.py index 6e15ba7d8..2443d20dd 100644 --- a/oauth/loginmanager.py +++ b/oauth/loginmanager.py @@ -1,7 +1,11 @@ -import features - from oauth.services.github import GithubOAuthService from oauth.services.google import GoogleOAuthService +from oauth.oidc import OIDCLoginService + +CUSTOM_LOGIN_SERVICES = { + 'GITHUB_LOGIN_CONFIG': GithubOAuthService, + 'GOOGLE_LOGIN_CONFIG': GoogleOAuthService, +} class OAuthLoginManager(object): """ Helper class which manages all registered OAuth login services. """ @@ -9,11 +13,12 @@ class OAuthLoginManager(object): self.services = [] # Register the endpoints for each of the OAuth login services. - # TODO(jschorr): make this dynamic. - if config.get('GITHUB_LOGIN_CONFIG') is not None and features.GITHUB_LOGIN: - github_service = GithubOAuthService(config, 'GITHUB_LOGIN_CONFIG') - self.services.append(github_service) - - if config.get('GOOGLE_LOGIN_CONFIG') is not None and features.GOOGLE_LOGIN: - google_service = GoogleOAuthService(config, 'GOOGLE_LOGIN_CONFIG') - self.services.append(google_service) + for key in config.keys(): + # All keys which end in _LOGIN_CONFIG setup a login service. + if key.endswith('_LOGIN_CONFIG'): + if key in CUSTOM_LOGIN_SERVICES: + custom_service = CUSTOM_LOGIN_SERVICES[key](config, key) + if custom_service.login_enabled(config): + self.services.append(custom_service) + else: + self.services.append(OIDCLoginService(config, key)) diff --git a/oauth/oidc.py b/oauth/oidc.py index 161be7418..b51668a06 100644 --- a/oauth/oidc.py +++ b/oauth/oidc.py @@ -3,108 +3,244 @@ import json import logging import urlparse +import jwt + from cachetools import lru_cache from cachetools.ttl import TTLCache +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.serialization import load_der_public_key +from jwkest.jwk import KEYS -from util.oauth.base import OAuthService +from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException +from oauth.login import OAuthLoginException +from util.security.jwtutil import decode, InvalidTokenError +from util import get_app_url logger = logging.getLogger(__name__) -def decode_user_jwt(token, oidc_provider): - try: - return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'], - audience=oidc_provider.client_id(), - issuer=oidc_provider.issuer) - except InvalidTokenError: - # Public key may have expired. Try to retrieve an updated public key and use it to decode. - return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'], - audience=oidc_provider.client_id(), - issuer=oidc_provider.issuer) - OIDC_WELLKNOWN = ".well-known/openid-configuration" PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour +ALLOWED_ALGORITHMS = ['RS256'] +JWT_CLOCK_SKEW_SECONDS = 30 -class OIDCConfig(OAuthService): +class DiscoveryFailureException(Exception): + """ Exception raised when OIDC discovery fails. """ + pass + + +class PublicKeyLoadException(Exception): + """ Exception raised if loading the OIDC public key fails. """ + pass + + +class OIDCLoginService(OAuthService): + """ Defines a generic service for all OpenID-connect compatible login services. """ def __init__(self, config, key_name): - super(OIDCConfig, self).__init__(config, key_name) + super(OIDCLoginService, self).__init__(config, key_name) - self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key) - self._config = config + self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key) + self._id = key_name[0:key_name.find('_')].lower() self._http_client = config['HTTPCLIENT'] + self._mailing = config.get('FEATURE_MAILING', False) - @lru_cache(maxsize=1) - def _oidc_config(self): - if self.config.get('OIDC_SERVER'): - return self._load_via_discovery(self._config.get('DEBUGGING', False)) - else: - return {} + def service_id(self): + return self._id - def _load_via_discovery(self, is_debugging): - oidc_server = self.config['OIDC_SERVER'] - if not oidc_server.startswith('https://') and not is_debugging: - raise Exception('OIDC server must be accessed over SSL') + def service_name(self): + return self.config.get('SERVICE_NAME', self.service_id()) - discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN) - discovery = self._http_client.get(discovery_url, timeout=5) + def get_icon(self): + return self.config.get('SERVICE_ICON', 'fa-user-circle') - if discovery.status_code / 100 != 2: - raise Exception("Could not load OIDC discovery information") + def get_login_scopes(self): + default_scopes = ['openid'] - try: - return json.loads(discovery.text) - except ValueError: - logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) - raise Exception("Could not parse OIDC discovery information") + if self.user_endpoint() is not None: + default_scopes.append('profile') + + if self._mailing: + default_scopes.append('email') + + return self._oidc_config().get('scopes_supported', default_scopes) def authorize_endpoint(self): - return self._oidc_config().get('authorization_endpoint', '') + '?' + return self._oidc_config().get('authorization_endpoint', '') + '?response_type=code&' def token_endpoint(self): return self._oidc_config().get('token_endpoint') def user_endpoint(self): - return None + return self._oidc_config().get('userinfo_endpoint') def validate_client_id_and_secret(self, http_client, app_config): - pass + # TODO: find a way to verify client secret too. + redirect_url = '%s/oauth2/%s/callback' % (get_app_url(app_config), self.service_id()) + scopes_string = ' '.join(self.get_login_scopes()) + authorize_url = '%sclient_id=%s&redirect_uri=%s&scope=%s' % (self.authorize_endpoint(), + self.client_id(), + redirect_url, + scopes_string) + + check_auth_url = http_client.get(authorize_url) + if check_auth_url.status_code // 100 != 2: + raise Exception('Got non-200 status code for authorization endpoint') + + def requires_form_encoding(self): + return True def get_public_config(self): return { 'CLIENT_ID': self.client_id(), - 'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), 'OIDC': True, } + def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix): + # Exchange the code for the access token and id_token + try: + json_data = self.exchange_code(app_config, http_client, code, + redirect_suffix=redirect_suffix, + form_encode=self.requires_form_encoding()) + except OAuthExchangeCodeException as oce: + raise OAuthLoginException(oce.message) + + # Make sure we received both. + access_token = json_data.get('access_token', None) + if access_token is None: + logger.debug('Missing access_token in response: %s', json_data) + raise OAuthLoginException('Missing `access_token` in OIDC response') + + id_token = json_data.get('id_token', None) + if id_token is None: + logger.debug('Missing id_token in response: %s', json_data) + raise OAuthLoginException('Missing `id_token` in OIDC response') + + # Decode the id_token. + try: + decoded_id_token = self._decode_user_jwt(id_token) + except InvalidTokenError as ite: + logger.exception('Got invalid token error on OIDC decode: %s', ite.message) + raise OAuthLoginException('Could not decode OIDC token') + except PublicKeyLoadException as pke: + logger.exception('Could not load public key during OIDC decode: %s', pke.message) + raise OAuthLoginException('Could find public OIDC key') + + # Retrieve the user information. + try: + user_info = self.get_user_info(http_client, access_token) + except OAuthGetUserInfoException as oge: + raise OAuthLoginException(oge.message) + + # Verify subs. + if user_info['sub'] != decoded_id_token['sub']: + raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint') + + # Check if we have a verified email address. + email_address = user_info.get('email') if user_info.get('email_verified') else None + if self._mailing: + if email_address is None: + raise OAuthLoginException('A verified email address is required to login with this service') + + # Check for a preferred username. + lusername = user_info.get('preferred_username') or user_info.get('sub') + return decoded_id_token['sub'], lusername, email_address + @property - def issuer(self): + def _issuer(self): return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER']) - def get_public_key(self, force_refresh=False): - """ Retrieves the public key for this handler. """ + @lru_cache(maxsize=1) + def _oidc_config(self): + if self.config.get('OIDC_SERVER'): + return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False)) + else: + return {} + + def _load_oidc_config_via_discovery(self, is_debugging): + """ Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True, + non-secure connections are alllowed. Raises an DiscoveryFailureException on failure. + """ + oidc_server = self.config['OIDC_SERVER'] + if not oidc_server.startswith('https://') and not is_debugging: + raise DiscoveryFailureException('OIDC server must be accessed over SSL') + + discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN) + discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging) + if discovery.status_code // 100 != 2: + logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text) + raise DiscoveryFailureException("Could not load OIDC discovery information") + + try: + return json.loads(discovery.text) + except ValueError: + logger.exception('Could not parse OIDC discovery for url: %s', discovery_url) + raise DiscoveryFailureException("Could not parse OIDC discovery information") + + def _decode_user_jwt(self, token): + """ Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError + exception on an invalid token or a PublicKeyLoadException if the public key could not be + loaded for decoding. + """ + # Find the key to use. + headers = jwt.get_unverified_header(token) + kid = headers.get('kid', None) + if kid is None: + raise InvalidTokenError('Missing `kid` header') + + try: + return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS, + audience=self.client_id(), + issuer=self._issuer, + leeway=JWT_CLOCK_SKEW_SECONDS, + options=dict(require_nbf=False)) + except InvalidTokenError: + # Public key may have expired. Try to retrieve an updated public key and use it to decode. + return decode(token, self._get_public_key(kid, force_refresh=True), + algorithms=ALLOWED_ALGORITHMS, + audience=self.client_id(), + issuer=self._issuer, + leeway=JWT_CLOCK_SKEW_SECONDS, + options=dict(require_nbf=False)) + + def _get_public_key(self, kid, force_refresh=False): + """ Retrieves the public key for this handler with the given kid. Raises a + PublicKeyLoadException on failure. """ + # If force_refresh is true, we expire all the items in the cache by setting the time to # the current time + the expiration TTL. if force_refresh: self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL) # Retrieve the public key from the cache. If the cache does not contain the public key, - # it will internally call _get_public_key to retrieve it and then save it. The None is - # a random key chose to be stored in the cache, and could be anything. - return self._public_key_cache[None] + # it will internally call _load_public_key to retrieve it and then save it. + return self._public_key_cache[kid] - def _get_public_key(self, _): - """ Retrieves the public key for this handler. """ + def _load_public_key(self, kid): + """ Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException + on failure. + """ keys_url = self._oidc_config()['jwks_uri'] - keys = KEYS() - keys.load_from_url(keys_url) + # Load the keys. + try: + keys = KEYS() + keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False)) + except Exception as ex: + logger.exception('Exception loading public key') + raise PublicKeyLoadException(ex.message) - if not list(keys): - raise Exception('No keys provided by OIDC provider') + # Find the matching key. + keys_found = keys.by_kid(kid) + if len(keys_found) == 0: + raise PublicKeyLoadException('Public key %s not found' % kid) - rsa_key = list(keys)[0] - rsa_key.deserialize() + rsa_keys = [key for key in keys_found if key.kty == 'RSA'] + if len(rsa_keys) == 0: + raise PublicKeyLoadException('No RSA form of public key %s not found' % kid) + + matching_key = rsa_keys[0] + matching_key.deserialize() # Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing # issues. - return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend()) + return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend()) diff --git a/oauth/services/github.py b/oauth/services/github.py index cb4f3977d..da25775f7 100644 --- a/oauth/services/github.py +++ b/oauth/services/github.py @@ -9,6 +9,9 @@ class GithubOAuthService(OAuthLoginService): def __init__(self, config, key_name): super(GithubOAuthService, self).__init__(config, key_name) + def login_enabled(self, config): + return config.get('FEATURE_GITHUB_LOGIN', False) + def service_id(self): return 'github' diff --git a/oauth/services/google.py b/oauth/services/google.py index 19c6d27de..aaedfbac0 100644 --- a/oauth/services/google.py +++ b/oauth/services/google.py @@ -12,6 +12,9 @@ class GoogleOAuthService(OAuthLoginService): def __init__(self, config, key_name): super(GoogleOAuthService, self).__init__(config, key_name) + def login_enabled(self, config): + return config.get('FEATURE_GOOGLE_LOGIN', False) + def service_id(self): return 'google' diff --git a/oauth/test/test_loginmanager.py b/oauth/test/test_loginmanager.py new file mode 100644 index 000000000..491216104 --- /dev/null +++ b/oauth/test/test_loginmanager.py @@ -0,0 +1,62 @@ +from oauth.loginmanager import OAuthLoginManager +from oauth.services.github import GithubOAuthService +from oauth.services.google import GoogleOAuthService +from oauth.oidc import OIDCLoginService + +def test_login_manager_github(): + config = { + 'FEATURE_GITHUB_LOGIN': True, + 'GITHUB_LOGIN_CONFIG': {}, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 1 + assert isinstance(loginmanager.services[0], GithubOAuthService) + +def test_github_disabled(): + config = { + 'GITHUB_LOGIN_CONFIG': {}, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 0 + +def test_login_manager_google(): + config = { + 'FEATURE_GOOGLE_LOGIN': True, + 'GOOGLE_LOGIN_CONFIG': {}, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 1 + assert isinstance(loginmanager.services[0], GoogleOAuthService) + +def test_google_disabled(): + config = { + 'GOOGLE_LOGIN_CONFIG': {}, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 0 + +def test_oidc(): + config = { + 'SOMECOOL_LOGIN_CONFIG': {}, + 'HTTPCLIENT': None, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 1 + assert isinstance(loginmanager.services[0], OIDCLoginService) + +def test_multiple_oidc(): + config = { + 'SOMECOOL_LOGIN_CONFIG': {}, + 'ANOTHER_LOGIN_CONFIG': {}, + 'HTTPCLIENT': None, + } + + loginmanager = OAuthLoginManager(config) + assert len(loginmanager.services) == 2 + assert isinstance(loginmanager.services[0], OIDCLoginService) + assert isinstance(loginmanager.services[1], OIDCLoginService) diff --git a/oauth/test/test_oidc.py b/oauth/test/test_oidc.py new file mode 100644 index 000000000..cf9612136 --- /dev/null +++ b/oauth/test/test_oidc.py @@ -0,0 +1,273 @@ +# pylint: disable=redefined-outer-name, unused-argument, C0103, C0111, too-many-arguments + +import json +import time +import urlparse + +import jwt +import pytest +import requests + +from httmock import urlmatch, HTTMock +from Crypto.PublicKey import RSA +from jwkest.jwk import RSAKey + +from oauth.oidc import OIDCLoginService, OAuthLoginException + +@pytest.fixture() +def http_client(): + sess = requests.Session() + adapter = requests.adapters.HTTPAdapter(pool_connections=100, + pool_maxsize=100) + sess.mount('http://', adapter) + sess.mount('https://', adapter) + return sess + +@pytest.fixture(params=[True, False]) +def app_config(http_client, request): + return { + 'PREFERRED_URL_SCHEME': 'http', + 'SERVER_HOSTNAME': 'localhost', + 'FEATURE_MAILING': request.param, + + 'SOMEOIDC_TEST_SERVICE': { + 'CLIENT_ID': 'foo', + 'CLIENT_SECRET': 'bar', + 'SERVICE_NAME': 'Some Cool Service', + 'SERVICE_ICON': 'http://some/icon', + 'OIDC_SERVER': 'http://fakeoidc', + 'DEBUGGING': True, + }, + + 'HTTPCLIENT': http_client, + } + +@pytest.fixture() +def oidc_service(app_config): + return OIDCLoginService(app_config, 'SOMEOIDC_TEST_SERVICE') + +@pytest.fixture() +def discovery_content(): + return { + 'scopes_supported': ['profile'], + 'authorization_endpoint': 'http://fakeoidc/authorize', + 'token_endpoint': 'http://fakeoidc/token', + 'userinfo_endpoint': 'http://fakeoidc/userinfo', + 'jwks_uri': 'http://fakeoidc/jwks', + } + +@pytest.fixture() +def discovery_handler(discovery_content): + @urlmatch(netloc=r'fakeoidc', path=r'.+openid.+') + def handler(_, __): + return json.dumps(discovery_content) + + return handler + +@pytest.fixture(scope="module") # Slow to generate, only do it once. +def signing_key(): + private_key = RSA.generate(2048) + jwk = RSAKey(key=private_key.publickey()).serialize() + return { + 'id': 'somekey', + 'private_key': private_key.exportKey('PEM'), + 'jwk': jwk, + } + +@pytest.fixture() +def id_token(oidc_service, signing_key, app_config): + token_data = { + 'iss': oidc_service.config['OIDC_SERVER'], + 'aud': oidc_service.client_id(), + 'nbf': int(time.time()), + 'iat': int(time.time()), + 'exp': int(time.time() + 600), + 'sub': 'cooluser', + } + + token_headers = { + 'kid': signing_key['id'], + } + + return jwt.encode(token_data, signing_key['private_key'], 'RS256', headers=token_headers) + +@pytest.fixture() +def valid_code(): + return 'validcode' + +@pytest.fixture() +def token_handler(oidc_service, id_token, valid_code): + @urlmatch(netloc=r'fakeoidc', path=r'/token') + def handler(_, request): + params = urlparse.parse_qs(request.body) + if params.get('redirect_uri')[0] != 'http://localhost/oauth2/someoidc/callback': + return {'status_code': 400, 'content': 'Invalid redirect URI'} + + if params.get('client_id')[0] != oidc_service.client_id(): + return {'status_code': 401, 'content': 'Invalid client id'} + + if params.get('client_secret')[0] != oidc_service.client_secret(): + return {'status_code': 401, 'content': 'Invalid client secret'} + + if params.get('code')[0] != valid_code: + return {'status_code': 401, 'content': 'Invalid code'} + + if params.get('grant_type')[0] != 'authorization_code': + return {'status_code': 400, 'content': 'Invalid authorization type'} + + content = { + 'access_token': 'sometoken', + 'id_token': id_token, + } + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +@pytest.fixture() +def jwks_handler(signing_key): + def jwk_with_kid(kid, jwk): + jwk = jwk.copy() + jwk.update({'kid': kid}) + return jwk + + @urlmatch(netloc=r'fakeoidc', path=r'/jwks') + def handler(_, __): + content = {'keys': [jwk_with_kid(signing_key['id'], signing_key['jwk'])]} + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +@pytest.fixture() +def emptykeys_jwks_handler(): + @urlmatch(netloc=r'fakeoidc', path=r'/jwks') + def handler(_, __): + content = {'keys': []} + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +@pytest.fixture(params=["someusername", None]) +def preferred_username(request): + return request.param + +@pytest.fixture +def userinfo_handler(oidc_service, preferred_username): + @urlmatch(netloc=r'fakeoidc', path=r'/userinfo') + def handler(_, __): + content = { + 'sub': 'cooluser', + 'preferred_username':preferred_username, + 'email': 'foo@example.com', + 'email_verified': True, + } + + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +@pytest.fixture() +def invalidsub_userinfo_handler(oidc_service): + @urlmatch(netloc=r'fakeoidc', path=r'/userinfo') + def handler(_, __): + content = { + 'sub': 'invalidsub', + 'preferred_username': 'someusername', + 'email': 'foo@example.com', + 'email_verified': True, + } + + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +@pytest.fixture() +def missingemail_userinfo_handler(oidc_service, preferred_username): + @urlmatch(netloc=r'fakeoidc', path=r'/userinfo') + def handler(_, __): + content = { + 'sub': 'cooluser', + 'preferred_username': preferred_username, + } + + return {'status_code': 200, 'content': json.dumps(content)} + + return handler + +def test_basic_config(oidc_service): + assert oidc_service.service_id() == 'someoidc' + assert oidc_service.service_name() == 'Some Cool Service' + assert oidc_service.get_icon() == 'http://some/icon' + +def test_discovery(oidc_service, http_client, discovery_handler): + with HTTMock(discovery_handler): + assert oidc_service.authorize_endpoint() == 'http://fakeoidc/authorize?response_type=code&' + assert oidc_service.token_endpoint() == 'http://fakeoidc/token' + assert oidc_service.user_endpoint() == 'http://fakeoidc/userinfo' + assert oidc_service.get_login_scopes() == ['profile'] + +def test_public_config(oidc_service, discovery_handler): + with HTTMock(discovery_handler): + assert oidc_service.get_public_config()['OIDC'] + assert oidc_service.get_public_config()['CLIENT_ID'] == 'foo' + + assert 'CLIENT_SECRET' not in oidc_service.get_public_config() + assert 'bar' not in oidc_service.get_public_config().values() + +def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client, + token_handler): + with HTTMock(token_handler, discovery_handler): + with pytest.raises(OAuthLoginException): + oidc_service.exchange_code_for_login(app_config, http_client, 'testcode', '') + +def test_exchange_code_validcode(oidc_service, discovery_handler, app_config, http_client, + token_handler, userinfo_handler, jwks_handler, valid_code, + preferred_username): + with HTTMock(jwks_handler, token_handler, userinfo_handler, discovery_handler): + lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client, + valid_code, '') + + assert lid == 'cooluser' + assert lemail == 'foo@example.com' + + if preferred_username is not None: + assert lusername == preferred_username + else: + assert lusername == lid + +def test_exchange_code_missingemail(oidc_service, discovery_handler, app_config, http_client, + token_handler, missingemail_userinfo_handler, jwks_handler, + valid_code, preferred_username): + with HTTMock(jwks_handler, token_handler, missingemail_userinfo_handler, discovery_handler): + if app_config['FEATURE_MAILING']: + # Should fail because there is no valid email address. + with pytest.raises(OAuthLoginException): + oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '') + else: + # Should succeed because, while there is no valid email address, it isn't necessary with + # mailing disabled. + lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client, + valid_code, '') + + assert lid == 'cooluser' + assert lemail is None + + if preferred_username is not None: + assert lusername == preferred_username + else: + assert lusername == lid + +def test_exchange_code_invalidsub(oidc_service, discovery_handler, app_config, http_client, + token_handler, invalidsub_userinfo_handler, jwks_handler, + valid_code): + with HTTMock(jwks_handler, token_handler, invalidsub_userinfo_handler, discovery_handler): + # Should fail because the sub of the user info doesn't match that returned by the id_token. + with pytest.raises(OAuthLoginException): + oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '') + +def test_exchange_code_missingkey(oidc_service, discovery_handler, app_config, http_client, + token_handler, userinfo_handler, emptykeys_jwks_handler, + valid_code): + with HTTMock(emptykeys_jwks_handler, token_handler, userinfo_handler, discovery_handler): + # Should fail because the key is missing. + with pytest.raises(OAuthLoginException): + oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '') diff --git a/test/test_endpoints.py b/test/test_endpoints.py index 216021245..256b54bd4 100644 --- a/test/test_endpoints.py +++ b/test/test_endpoints.py @@ -8,7 +8,6 @@ import base64 from urllib import urlencode from urlparse import urlparse, urlunparse, parse_qs from datetime import datetime, timedelta -from httmock import urlmatch, HTTMock import jwt @@ -16,7 +15,7 @@ from Crypto.PublicKey import RSA from flask import url_for from jwkest.jwk import RSAKey -from app import app, oauth_login +from app import app from data import model from data.database import ServiceKeyApprovalType from endpoints import keyserver @@ -25,16 +24,9 @@ from endpoints.api.user import Signin from endpoints.keyserver import jwk_with_kid from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME from endpoints.web import web as web_bp -from endpoints.oauthlogin import oauthlogin as oauthlogin_bp from initdb import setup_database_for_testing, finished_database_for_testing from test.helpers import assert_action_logged -try: - app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2') -except ValueError: - # This blueprint was already registered - pass - try: app.register_blueprint(web_bp, url_prefix='') except ValueError: @@ -129,96 +121,6 @@ class EndpointTestCase(unittest.TestCase): self.assertEquals(rv.status_code, 200) -class OAuthLoginTestCase(EndpointTestCase): - def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident, - new_username): - # Test callback. - created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username) - - # Delete the created user. - model.user.delete_user(created, []) - - # Test attach. - self.login('devtable', 'password') - self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable') - - def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username): - # No CSRF. - self.getResponse('oauthlogin.' + endpoint_name, expected_code=403) - - # Invalid CSRF. - self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403) - - # Valid CSRF, invalid code. - self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', - code='invalidcode', expected_code=400) - - # Valid CSRF, valid code. - self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', - code='somecode', expected_code=302) - - # Ensure the user was added/modified. - found_user = model.user.get_user(username) - self.assertIsNotNone(found_user) - - federated_login = model.user.lookup_federated_login(found_user, service_name) - self.assertIsNotNone(federated_login) - self.assertEquals(federated_login.service_ident, service_ident) - return found_user - - def test_google_oauth(self): - @urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token') - def account_handler(_, request): - if request.body.find("code=somecode") > 0: - content = {'access_token': 'someaccesstoken'} - return py_json.dumps(content) - else: - return {'status_code': 400, 'content': '{"message": "Invalid code"}'} - - @urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo') - def user_handler(_, __): - content = { - 'id': 'someid', - 'email': 'someemail@example.com', - 'verified_email': True, - } - return py_json.dumps(content) - - with HTTMock(account_handler, user_handler): - self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google', - 'someid', 'someemail') - - def test_github_oauth(self): - @urlmatch(netloc=r'github.com', path='/login/oauth/access_token') - def account_handler(url, _): - if url.query.find("code=somecode") > 0: - content = {'access_token': 'someaccesstoken'} - return py_json.dumps(content) - else: - return {'status_code': 400, 'content': '{"message": "Invalid code"}'} - - @urlmatch(netloc=r'github.com', path='/api/v3/user') - def user_handler(_, __): - content = { - 'id': 'someid', - 'login': 'someusername' - } - return py_json.dumps(content) - - @urlmatch(netloc=r'github.com', path='/api/v3/user/emails') - def email_handler(_, __): - content = [{ - 'email': 'someemail@example.com', - 'verified': True, - 'primary': True, - }] - return py_json.dumps(content) - - with HTTMock(account_handler, email_handler, user_handler): - self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github', - 'someid', 'someusername') - - class WebEndpointTestCase(EndpointTestCase): def test_index(self): self.getResponse('web.index') diff --git a/test/test_oauth_login.py b/test/test_oauth_login.py new file mode 100644 index 000000000..9590a39aa --- /dev/null +++ b/test/test_oauth_login.py @@ -0,0 +1,175 @@ +import json as py_json +import time +import unittest + +import jwt + +from Crypto.PublicKey import RSA +from httmock import urlmatch, HTTMock +from jwkest.jwk import RSAKey + +from app import app +from data import model +from endpoints.oauthlogin import oauthlogin as oauthlogin_bp +from test.test_endpoints import EndpointTestCase + +try: + app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2') +except ValueError: + # This blueprint was already registered + pass + +class OAuthLoginTestCase(EndpointTestCase): + def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident, + new_username): + # Test callback. + created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username) + + # Delete the created user. + model.user.delete_user(created, []) + + # Test attach. + self.login('devtable', 'password') + self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable') + + def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username): + # No CSRF. + self.getResponse('oauthlogin.' + endpoint_name, expected_code=403) + + # Invalid CSRF. + self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403) + + # Valid CSRF, invalid code. + self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', + code='invalidcode', expected_code=400) + + # Valid CSRF, valid code. + self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken', + code='somecode', expected_code=302) + + # Ensure the user was added/modified. + found_user = model.user.get_user(username) + self.assertIsNotNone(found_user) + + federated_login = model.user.lookup_federated_login(found_user, service_name) + self.assertIsNotNone(federated_login) + self.assertEquals(federated_login.service_ident, service_ident) + return found_user + + def test_google_oauth(self): + @urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token') + def account_handler(_, request): + if request.body.find("code=somecode") > 0: + content = {'access_token': 'someaccesstoken'} + return py_json.dumps(content) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo') + def user_handler(_, __): + content = { + 'id': 'someid', + 'email': 'someemail@example.com', + 'verified_email': True, + } + return py_json.dumps(content) + + with HTTMock(account_handler, user_handler): + self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google', + 'someid', 'someemail') + + def test_github_oauth(self): + @urlmatch(netloc=r'github.com', path='/login/oauth/access_token') + def account_handler(url, _): + if url.query.find("code=somecode") > 0: + content = {'access_token': 'someaccesstoken'} + return py_json.dumps(content) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'github.com', path='/api/v3/user') + def user_handler(_, __): + content = { + 'id': 'someid', + 'login': 'someusername' + } + return py_json.dumps(content) + + @urlmatch(netloc=r'github.com', path='/api/v3/user/emails') + def email_handler(_, __): + content = [{ + 'email': 'someemail@example.com', + 'verified': True, + 'primary': True, + }] + return py_json.dumps(content) + + with HTTMock(account_handler, email_handler, user_handler): + self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github', + 'someid', 'someusername') + + def test_oidc_auth(self): + private_key = RSA.generate(2048) + generatedjwk = RSAKey(key=private_key.publickey()).serialize() + kid = 'somekey' + private_pem = private_key.exportKey('PEM') + + token_data = { + 'iss': app.config['TESTOIDC_LOGIN_CONFIG']['OIDC_SERVER'], + 'aud': app.config['TESTOIDC_LOGIN_CONFIG']['CLIENT_ID'], + 'nbf': int(time.time()), + 'iat': int(time.time()), + 'exp': int(time.time() + 600), + 'sub': 'cooluser', + } + + token_headers = { + 'kid': kid, + } + + id_token = jwt.encode(token_data, private_pem, 'RS256', headers=token_headers) + + @urlmatch(netloc=r'fakeoidc', path='/token') + def token_handler(_, request): + if request.body.find("code=somecode") >= 0: + content = {'access_token': 'someaccesstoken', 'id_token': id_token} + return py_json.dumps(content) + else: + return {'status_code': 400, 'content': '{"message": "Invalid code"}'} + + @urlmatch(netloc=r'fakeoidc', path='/user') + def user_handler(_, __): + content = { + 'sub': 'cooluser', + 'preferred_username': 'someusername', + 'email': 'someemail@example.com', + 'email_verified': True, + } + return py_json.dumps(content) + + @urlmatch(netloc=r'fakeoidc', path='/jwks') + def jwks_handler(_, __): + jwk = generatedjwk.copy() + jwk.update({'kid': kid}) + + content = {'keys': [jwk]} + return py_json.dumps(content) + + @urlmatch(netloc=r'fakeoidc', path='.+openid.+') + def discovery_handler(_, __): + content = { + 'scopes_supported': ['profile'], + 'authorization_endpoint': 'http://fakeoidc/authorize', + 'token_endpoint': 'http://fakeoidc/token', + 'userinfo_endpoint': 'http://fakeoidc/userinfo', + 'jwks_uri': 'http://fakeoidc/jwks', + } + return py_json.dumps(content) + + with HTTMock(discovery_handler, jwks_handler, token_handler, user_handler): + self.invoke_oauth_tests('testoidc_oauth_callback', 'testoidc_oauth_attach', 'testoidc', + 'cooluser', 'someusername') + +if __name__ == '__main__': + unittest.main() + diff --git a/test/testconfig.py b/test/testconfig.py index eb976aa65..e9d6c03db 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -81,11 +81,12 @@ class TestConfig(DefaultConfig): FEATURE_GITHUB_LOGIN = True FEATURE_GOOGLE_LOGIN = True - FEATURE_DEX_LOGIN = True - DEX_LOGIN_CONFIG = { - 'CLIENT_ID': 'someclientid', - 'OIDC_SERVER': 'https://oidcserver/', + TESTOIDC_LOGIN_CONFIG = { + 'CLIENT_ID': 'foo', + 'CLIENT_SECRET': 'bar', + 'OIDC_SERVER': 'http://fakeoidc', + 'DEBUGGING': True, } RECAPTCHA_SITE_KEY = 'somekey'