diff --git a/oauth/base.py b/oauth/base.py index a699358bd..1e3d451d6 100644 --- a/oauth/base.py +++ b/oauth/base.py @@ -1,5 +1,7 @@ +import copy import logging import urllib +import urlparse from abc import ABCMeta, abstractmethod from six import add_metaclass @@ -8,6 +10,34 @@ from util import get_app_url logger = logging.getLogger(__name__) + +class OAuthEndpoint(object): + def __init__(self, base_url, params=None): + self.base_url = base_url + self.params = params or {} + + def with_param(self, name, value): + params_copy = copy.copy(self.params) + params_copy[name] = value + return OAuthEndpoint(self.base_url, params_copy) + + def with_params(self, parameters): + params_copy = copy.copy(self.params) + params_copy.update(parameters) + return OAuthEndpoint(self.base_url, params_copy) + + def to_url_prefix(self): + prefix = self.to_url() + if self.params: + return prefix + '&' + else: + return prefix + '?' + + def to_url(self): + (scheme, netloc, path, _, fragment) = urlparse.urlsplit(self.base_url) + updated_query = urllib.urlencode(self.params) + return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment)) + class OAuthExchangeCodeException(Exception): """ Exception raised if a code exchange fails. """ pass @@ -36,12 +66,17 @@ class OAuthService(object): @abstractmethod def token_endpoint(self): - """ The endpoint at which the OAuth code can be exchanged for a token. """ + """ Returns the endpoint at which the OAuth code can be exchanged for a token. """ pass @abstractmethod def user_endpoint(self): - """ The endpoint at which user information can be looked up. """ + """ Returns the endpoint at which user information can be looked up. """ + pass + + @abstractmethod + def authorize_endpoint(self): + """ Returns the for authorization of the OAuth service. """ pass @abstractmethod @@ -49,11 +84,6 @@ class OAuthService(object): """ Performs validation of the client ID and secret, raising an exception on failure. """ pass - @abstractmethod - def authorize_endpoint(self): - """ Endpoint for authorization. """ - pass - def requires_form_encoding(self): """ Returns True if form encoding is necessary for the exchange_code_for_token call. """ return False @@ -86,8 +116,7 @@ class OAuthService(object): 'state': csrf_token, } - authorize_url = '%s%s' % (self.authorize_endpoint(), urllib.urlencode(params)) - return authorize_url + return self.authorize_endpoint().with_params(params).to_url() def get_redirect_uri(self, app_config, redirect_suffix=''): return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], @@ -104,7 +133,7 @@ class OAuthService(object): 'Authorization': 'Bearer %s' % token, } - got_user = http_client.get(self.user_endpoint(), params=token_param, headers=headers) + got_user = http_client.get(self.user_endpoint().to_url(), params=token_param, headers=headers) if got_user.status_code // 100 != 2: raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' % got_user.status_code) @@ -148,7 +177,7 @@ class OAuthService(object): payload['client_id'] = self.client_id() payload['client_secret'] = self.client_secret() - token_url = self.token_endpoint() + token_url = self.token_endpoint().to_url() if form_encode: get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth) else: diff --git a/oauth/oidc.py b/oauth/oidc.py index 232ef68ae..e6292530f 100644 --- a/oauth/oidc.py +++ b/oauth/oidc.py @@ -2,7 +2,6 @@ import time import json import logging import urlparse -import urllib import jwt @@ -12,7 +11,8 @@ from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.serialization import load_der_public_key from jwkest.jwk import KEYS -from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException +from oauth.base import (OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException, + OAuthEndpoint) from oauth.login import OAuthLoginException from util.security.jwtutil import decode, InvalidTokenError @@ -66,7 +66,7 @@ class OIDCLoginService(OAuthService): return list(set(login_scopes) & set(supported_scopes)) def authorize_endpoint(self): - return self._get_endpoint('authorization_endpoint', response_type='code') + return self._get_endpoint('authorization_endpoint').with_param('response_type', 'code') def token_endpoint(self): return self._get_endpoint('token_endpoint') @@ -92,16 +92,14 @@ class OIDCLoginService(OAuthService): query_params = urlparse.parse_qs(query, keep_blank_values=True) query_params.update(kwargs) query_params.update(custom_parameters) - - updated_query = urllib.urlencode(query_params) - return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment)) + return OAuthEndpoint(urlparse.urlunsplit((scheme, netloc, path, {}, fragment)), query_params) def validate(self): return bool(self.get_login_scopes()) def validate_client_id_and_secret(self, http_client, app_config): # TODO: find a way to verify client secret too. - check_auth_url = http_client.get(self.get_auth_url()) + check_auth_url = http_client.get(self.get_auth_url(app_config, '', '', [])) if check_auth_url.status_code // 100 != 2: raise Exception('Got non-200 status code for authorization endpoint') diff --git a/oauth/services/github.py b/oauth/services/github.py index da25775f7..3923b6c95 100644 --- a/oauth/services/github.py +++ b/oauth/services/github.py @@ -1,5 +1,6 @@ import logging +from oauth.base import OAuthEndpoint from oauth.login import OAuthLoginService, OAuthLoginException from util import slash_join @@ -50,10 +51,13 @@ class GithubOAuthService(OAuthLoginService): return self._api_endpoint().find('.github.com') < 0 def authorize_endpoint(self): - return slash_join(self._endpoint(), '/login/oauth/authorize') + '?' + return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/authorize')) def token_endpoint(self): - return slash_join(self._endpoint(), '/login/oauth/access_token') + return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/access_token')) + + def user_endpoint(self): + return OAuthEndpoint(slash_join(self._api_endpoint(), 'user')) def _api_endpoint(self): return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/')) @@ -65,9 +69,6 @@ class GithubOAuthService(OAuthLoginService): return endpoint - def user_endpoint(self): - return slash_join(self._api_endpoint(), 'user') - def email_endpoint(self): return slash_join(self._api_endpoint(), 'user/emails') @@ -112,7 +113,7 @@ class GithubOAuthService(OAuthLoginService): def get_public_config(self): return { 'CLIENT_ID': self.client_id(), - 'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), + 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix(), 'GITHUB_ENDPOINT': self._endpoint(), 'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False) } diff --git a/oauth/services/gitlab.py b/oauth/services/gitlab.py index f05d9c617..4ac0dda22 100644 --- a/oauth/services/gitlab.py +++ b/oauth/services/gitlab.py @@ -1,4 +1,4 @@ -from oauth.base import OAuthService +from oauth.base import OAuthService, OAuthEndpoint from util import slash_join class GitLabOAuthService(OAuthService): @@ -24,17 +24,17 @@ class GitLabOAuthService(OAuthService): return slash_join(self._endpoint(), suffix) def authorize_endpoint(self): - return slash_join(self._endpoint(), '/oauth/authorize') + return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/authorize')) def token_endpoint(self): - return slash_join(self._endpoint(), '/oauth/token') + return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/token')) def validate_client_id_and_secret(self, http_client, app_config): # We validate the client ID and secret by hitting the OAuth token exchange endpoint with # the real client ID and secret, but a fake auth code to exchange. Gitlab's implementation will # return `invalid_client` as the `error` if the client ID or secret is invalid; otherwise, it # will return another error. - url = self.token_endpoint() + url = self.token_endpoint().to_url() redirect_uri = self.get_redirect_uri(app_config, redirect_suffix='trigger') data = { 'code': 'fakecode', @@ -55,6 +55,6 @@ class GitLabOAuthService(OAuthService): def get_public_config(self): return { 'CLIENT_ID': self.client_id(), - 'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), + 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix(), 'GITLAB_ENDPOINT': self._endpoint(), } diff --git a/oauth/services/google.py b/oauth/services/google.py index aaedfbac0..ede5203dd 100644 --- a/oauth/services/google.py +++ b/oauth/services/google.py @@ -1,3 +1,4 @@ +from oauth.base import OAuthEndpoint from oauth.login import OAuthLoginService def _get_email_username(email_address): @@ -28,13 +29,14 @@ class GoogleOAuthService(OAuthLoginService): return ['openid', 'email'] def authorize_endpoint(self): - return 'https://accounts.google.com/o/oauth2/auth?response_type=code&' + return OAuthEndpoint('https://accounts.google.com/o/oauth2/auth', + params=dict(response_type='code')) def token_endpoint(self): - return 'https://accounts.google.com/o/oauth2/token' + return OAuthEndpoint('https://accounts.google.com/o/oauth2/token') def user_endpoint(self): - return 'https://www.googleapis.com/oauth2/v1/userinfo' + return OAuthEndpoint('https://www.googleapis.com/oauth2/v1/userinfo') def requires_form_encoding(self): return True @@ -59,7 +61,7 @@ class GoogleOAuthService(OAuthLoginService): def get_public_config(self): return { 'CLIENT_ID': self.client_id(), - 'AUTHORIZE_ENDPOINT': self.authorize_endpoint() + 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix() } def get_login_service_id(self, user_info): diff --git a/oauth/services/test/test_github.py b/oauth/services/test/test_github.py new file mode 100644 index 000000000..c19ec3f42 --- /dev/null +++ b/oauth/services/test/test_github.py @@ -0,0 +1,39 @@ +import pytest + +from oauth.services.github import GithubOAuthService + +@pytest.mark.parametrize('trigger_config, domain, api_endpoint, is_enterprise', [ + ({ + 'CLIENT_ID': 'someclientid', + 'CLIENT_SECRET': 'someclientsecret', + 'API_ENDPOINT': 'https://api.github.com/v3', + }, 'https://github.com', 'https://api.github.com/v3', False), + ({ + 'GITHUB_ENDPOINT': 'https://github.somedomain.com/', + 'CLIENT_ID': 'someclientid', + 'CLIENT_SECRET': 'someclientsecret', + }, 'https://github.somedomain.com', 'https://github.somedomain.com/api/v3', True), + ({ + 'GITHUB_ENDPOINT': 'https://github.somedomain.com/', + 'API_ENDPOINT': 'http://somedomain.com/api/', + 'CLIENT_ID': 'someclientid', + 'CLIENT_SECRET': 'someclientsecret', + }, 'https://github.somedomain.com', 'http://somedomain.com/api', True), +]) +def test_basic_enterprise_config(trigger_config, domain, api_endpoint, is_enterprise): + config = { + 'GITHUB_TRIGGER_CONFIG': trigger_config + } + + github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG') + assert github_trigger.is_enterprise() == is_enterprise + + assert github_trigger.authorize_endpoint().to_url() == '%s/login/oauth/authorize' % domain + assert github_trigger.authorize_endpoint().to_url_prefix() == '%s/login/oauth/authorize?' % domain + + assert github_trigger.token_endpoint().to_url() == '%s/login/oauth/access_token' % domain + + assert github_trigger.api_endpoint() == api_endpoint + assert github_trigger.user_endpoint().to_url() == '%s/user' % api_endpoint + assert github_trigger.email_endpoint() == '%s/user/emails' % api_endpoint + assert github_trigger.orgs_endpoint() == '%s/user/orgs' % api_endpoint diff --git a/oauth/test/test_oidc.py b/oauth/test/test_oidc.py index aaa7d9774..bd57defe2 100644 --- a/oauth/test/test_oidc.py +++ b/oauth/test/test_oidc.py @@ -154,6 +154,16 @@ def discovery_handler(discovery_content): return handler +@pytest.fixture() +def authorize_handler(discovery_content): + @urlmatch(netloc=r'fakeoidc', path=r'/authorize') + def handler(_, request): + parsed = urlparse.urlparse(request.url) + params = urlparse.parse_qs(parsed.query) + return json.dumps({'authorized': True, 'scope': params['scope'][0], 'state': params['state'][0]}) + + return handler + @pytest.fixture() def token_handler(oidc_service, id_token, valid_code): @urlmatch(netloc=r'fakeoidc', path=r'/token') @@ -237,16 +247,19 @@ def test_basic_config(oidc_service): def test_discovery(oidc_service, http_client, discovery_content, discovery_handler): with HTTMock(discovery_handler): auth = discovery_content['authorization_endpoint'] + '?response_type=code' - assert oidc_service.authorize_endpoint() == auth + assert oidc_service.authorize_endpoint().to_url() == auth + assert oidc_service.token_endpoint().to_url() == discovery_content['token_endpoint'] + + if discovery_content['userinfo_endpoint'] is None: + assert oidc_service.user_endpoint() is None + else: + assert oidc_service.user_endpoint().to_url() == discovery_content['userinfo_endpoint'] - assert oidc_service.token_endpoint() == discovery_content['token_endpoint'] - assert oidc_service.user_endpoint() == discovery_content['userinfo_endpoint'] assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported']) def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler): with HTTMock(discovery_handler): - auth = discovery_content['authorization_endpoint'] + '?response_type=code&some=param' - assert 'some=param' in oidc_withparams_service.authorize_endpoint() + assert 'some=param' in oidc_withparams_service.authorize_endpoint().to_url() def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler): with HTTMock(discovery_handler): @@ -260,6 +273,17 @@ def test_public_config(oidc_service, discovery_handler): assert 'CLIENT_SECRET' not in oidc_service.get_public_config() assert 'bar' not in oidc_service.get_public_config().values() +def test_auth_url(oidc_service, discovery_handler, http_client, authorize_handler): + config = {'PREFERRED_URL_SCHEME': 'https', 'SERVER_HOSTNAME': 'someserver'} + + with HTTMock(discovery_handler, authorize_handler): + auth_url = oidc_service.get_auth_url(config, '', 'some csrf token', ['one', 'two']) + + # Hit the URL and ensure it works. + result = http_client.get(auth_url).json() + assert result['state'] == 'some csrf token' + assert result['scope'] == 'one two' + def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client, token_handler): with HTTMock(token_handler, discovery_handler): diff --git a/test/test_github.py b/test/test_github.py deleted file mode 100644 index 1296af155..000000000 --- a/test/test_github.py +++ /dev/null @@ -1,48 +0,0 @@ -import unittest - -from oauth.services.github import GithubOAuthService - -class TestGithub(unittest.TestCase): - def test_basic_enterprise_config(self): - config = { - 'GITHUB_TRIGGER_CONFIG': { - 'GITHUB_ENDPOINT': 'https://github.somedomain.com/', - 'CLIENT_ID': 'someclientid', - 'CLIENT_SECRET': 'someclientsecret', - } - } - - github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG') - self.assertTrue(github_trigger.is_enterprise()) - self.assertEquals('https://github.somedomain.com/login/oauth/authorize?', github_trigger.authorize_endpoint()) - self.assertEquals('https://github.somedomain.com/login/oauth/access_token', github_trigger.token_endpoint()) - - self.assertEquals('https://github.somedomain.com/api/v3', github_trigger.api_endpoint()) - - self.assertEquals('https://github.somedomain.com/api/v3/user', github_trigger.user_endpoint()) - self.assertEquals('https://github.somedomain.com/api/v3/user/emails', github_trigger.email_endpoint()) - self.assertEquals('https://github.somedomain.com/api/v3/user/orgs', github_trigger.orgs_endpoint()) - - def test_custom_enterprise_config(self): - config = { - 'GITHUB_TRIGGER_CONFIG': { - 'GITHUB_ENDPOINT': 'https://github.somedomain.com/', - 'API_ENDPOINT': 'http://somedomain.com/api', - 'CLIENT_ID': 'someclientid', - 'CLIENT_SECRET': 'someclientsecret', - } - } - - github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG') - self.assertTrue(github_trigger.is_enterprise()) - self.assertEquals('https://github.somedomain.com/login/oauth/authorize?', github_trigger.authorize_endpoint()) - self.assertEquals('https://github.somedomain.com/login/oauth/access_token', github_trigger.token_endpoint()) - - self.assertEquals('http://somedomain.com/api', github_trigger.api_endpoint()) - - self.assertEquals('http://somedomain.com/api/user', github_trigger.user_endpoint()) - self.assertEquals('http://somedomain.com/api/user/emails', github_trigger.email_endpoint()) - self.assertEquals('http://somedomain.com/api/user/orgs', github_trigger.orgs_endpoint()) - -if __name__ == '__main__': - unittest.main()