diff --git a/oauth/oidc.py b/oauth/oidc.py index e19e5c466..232ef68ae 100644 --- a/oauth/oidc.py +++ b/oauth/oidc.py @@ -2,6 +2,7 @@ import time import json import logging import urlparse +import urllib import jwt @@ -65,13 +66,35 @@ class OIDCLoginService(OAuthService): return list(set(login_scopes) & set(supported_scopes)) def authorize_endpoint(self): - return self._oidc_config().get('authorization_endpoint', '') + '?response_type=code&' + return self._get_endpoint('authorization_endpoint', response_type='code') def token_endpoint(self): - return self._oidc_config().get('token_endpoint') + return self._get_endpoint('token_endpoint') def user_endpoint(self): - return self._oidc_config().get('userinfo_endpoint') + return self._get_endpoint('userinfo_endpoint') + + def _get_endpoint(self, endpoint_key, **kwargs): + """ Returns the OIDC endpoint with the given key found in the OIDC discovery + document, with the given kwargs added as query parameters. Additionally, + any defined parameters found in the OIDC configuration block are also + added. + """ + endpoint = self._oidc_config().get(endpoint_key, '') + if not endpoint: + return None + + (scheme, netloc, path, query, fragment) = urlparse.urlsplit(endpoint) + + # Add the query parameters from the kwargs and the config. + custom_parameters = self.config.get('OIDC_ENDPOINT_CUSTOM_PARAMS', {}).get(endpoint_key, {}) + + query_params = urlparse.parse_qs(query, keep_blank_values=True) + query_params.update(kwargs) + query_params.update(custom_parameters) + + updated_query = urllib.urlencode(query_params) + return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment)) def validate(self): return bool(self.get_login_scopes()) diff --git a/oauth/test/test_oidc.py b/oauth/test/test_oidc.py index 73208cbe4..aaa7d9774 100644 --- a/oauth/test/test_oidc.py +++ b/oauth/test/test_oidc.py @@ -79,6 +79,20 @@ def app_config(http_client, mailing_feature): 'DEBUGGING': True, }, + 'OIDCWITHPARAMS_LOGIN_CONFIG': { + 'CLIENT_ID': 'foo', + 'CLIENT_SECRET': 'bar', + 'SERVICE_NAME': 'Some Other Service', + 'SERVICE_ICON': 'http://some/icon', + 'OIDC_SERVER': 'http://fakeoidc', + 'DEBUGGING': True, + 'OIDC_ENDPOINT_CUSTOM_PARAMS': { + 'authorization_endpoint': { + 'some': 'param', + }, + }, + }, + 'HTTPCLIENT': http_client, } @@ -90,6 +104,10 @@ def oidc_service(app_config): def another_oidc_service(app_config): return OIDCLoginService(app_config, 'ANOTHEROIDC_LOGIN_CONFIG') +@pytest.fixture() +def oidc_withparams_service(app_config): + return OIDCLoginService(app_config, 'OIDCWITHPARAMS_LOGIN_CONFIG') + @pytest.fixture() def discovery_content(userinfo_supported): return { @@ -218,13 +236,18 @@ def test_basic_config(oidc_service): def test_discovery(oidc_service, http_client, discovery_content, discovery_handler): with HTTMock(discovery_handler): - auth = discovery_content['authorization_endpoint'] + '?response_type=code&' + auth = discovery_content['authorization_endpoint'] + '?response_type=code' assert oidc_service.authorize_endpoint() == auth assert oidc_service.token_endpoint() == discovery_content['token_endpoint'] assert oidc_service.user_endpoint() == discovery_content['userinfo_endpoint'] assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported']) +def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler): + with HTTMock(discovery_handler): + auth = discovery_content['authorization_endpoint'] + '?response_type=code&some=param' + assert 'some=param' in oidc_withparams_service.authorize_endpoint() + def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler): with HTTMock(discovery_handler): assert another_oidc_service.get_login_scopes() == ['openid']