Fix bug with missing & in authorization URL for OIDC

Also adds testing to ensure we don't break this again
This commit is contained in:
Joseph Schorr 2018-05-15 13:28:43 -04:00
parent 4c0ab81ac8
commit 22a39c3007
8 changed files with 131 additions and 86 deletions

View file

@ -1,5 +1,7 @@
import copy
import logging import logging
import urllib import urllib
import urlparse
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from six import add_metaclass from six import add_metaclass
@ -8,6 +10,34 @@ from util import get_app_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class OAuthEndpoint(object):
def __init__(self, base_url, params=None):
self.base_url = base_url
self.params = params or {}
def with_param(self, name, value):
params_copy = copy.copy(self.params)
params_copy[name] = value
return OAuthEndpoint(self.base_url, params_copy)
def with_params(self, parameters):
params_copy = copy.copy(self.params)
params_copy.update(parameters)
return OAuthEndpoint(self.base_url, params_copy)
def to_url_prefix(self):
prefix = self.to_url()
if self.params:
return prefix + '&'
else:
return prefix + '?'
def to_url(self):
(scheme, netloc, path, _, fragment) = urlparse.urlsplit(self.base_url)
updated_query = urllib.urlencode(self.params)
return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment))
class OAuthExchangeCodeException(Exception): class OAuthExchangeCodeException(Exception):
""" Exception raised if a code exchange fails. """ """ Exception raised if a code exchange fails. """
pass pass
@ -36,12 +66,17 @@ class OAuthService(object):
@abstractmethod @abstractmethod
def token_endpoint(self): def token_endpoint(self):
""" The endpoint at which the OAuth code can be exchanged for a token. """ """ Returns the endpoint at which the OAuth code can be exchanged for a token. """
pass pass
@abstractmethod @abstractmethod
def user_endpoint(self): def user_endpoint(self):
""" The endpoint at which user information can be looked up. """ """ Returns the endpoint at which user information can be looked up. """
pass
@abstractmethod
def authorize_endpoint(self):
""" Returns the for authorization of the OAuth service. """
pass pass
@abstractmethod @abstractmethod
@ -49,11 +84,6 @@ class OAuthService(object):
""" Performs validation of the client ID and secret, raising an exception on failure. """ """ Performs validation of the client ID and secret, raising an exception on failure. """
pass pass
@abstractmethod
def authorize_endpoint(self):
""" Endpoint for authorization. """
pass
def requires_form_encoding(self): def requires_form_encoding(self):
""" Returns True if form encoding is necessary for the exchange_code_for_token call. """ """ Returns True if form encoding is necessary for the exchange_code_for_token call. """
return False return False
@ -86,8 +116,7 @@ class OAuthService(object):
'state': csrf_token, 'state': csrf_token,
} }
authorize_url = '%s%s' % (self.authorize_endpoint(), urllib.urlencode(params)) return self.authorize_endpoint().with_params(params).to_url()
return authorize_url
def get_redirect_uri(self, app_config, redirect_suffix=''): def get_redirect_uri(self, app_config, redirect_suffix=''):
return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
@ -104,7 +133,7 @@ class OAuthService(object):
'Authorization': 'Bearer %s' % token, 'Authorization': 'Bearer %s' % token,
} }
got_user = http_client.get(self.user_endpoint(), params=token_param, headers=headers) got_user = http_client.get(self.user_endpoint().to_url(), params=token_param, headers=headers)
if got_user.status_code // 100 != 2: if got_user.status_code // 100 != 2:
raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' % raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' %
got_user.status_code) got_user.status_code)
@ -148,7 +177,7 @@ class OAuthService(object):
payload['client_id'] = self.client_id() payload['client_id'] = self.client_id()
payload['client_secret'] = self.client_secret() payload['client_secret'] = self.client_secret()
token_url = self.token_endpoint() token_url = self.token_endpoint().to_url()
if form_encode: if form_encode:
get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth) get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
else: else:

View file

@ -2,7 +2,6 @@ import time
import json import json
import logging import logging
import urlparse import urlparse
import urllib
import jwt import jwt
@ -12,7 +11,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS from jwkest.jwk import KEYS
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException from oauth.base import (OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException,
OAuthEndpoint)
from oauth.login import OAuthLoginException from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError from util.security.jwtutil import decode, InvalidTokenError
@ -66,7 +66,7 @@ class OIDCLoginService(OAuthService):
return list(set(login_scopes) & set(supported_scopes)) return list(set(login_scopes) & set(supported_scopes))
def authorize_endpoint(self): def authorize_endpoint(self):
return self._get_endpoint('authorization_endpoint', response_type='code') return self._get_endpoint('authorization_endpoint').with_param('response_type', 'code')
def token_endpoint(self): def token_endpoint(self):
return self._get_endpoint('token_endpoint') return self._get_endpoint('token_endpoint')
@ -92,16 +92,14 @@ class OIDCLoginService(OAuthService):
query_params = urlparse.parse_qs(query, keep_blank_values=True) query_params = urlparse.parse_qs(query, keep_blank_values=True)
query_params.update(kwargs) query_params.update(kwargs)
query_params.update(custom_parameters) query_params.update(custom_parameters)
return OAuthEndpoint(urlparse.urlunsplit((scheme, netloc, path, {}, fragment)), query_params)
updated_query = urllib.urlencode(query_params)
return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment))
def validate(self): def validate(self):
return bool(self.get_login_scopes()) return bool(self.get_login_scopes())
def validate_client_id_and_secret(self, http_client, app_config): def validate_client_id_and_secret(self, http_client, app_config):
# TODO: find a way to verify client secret too. # TODO: find a way to verify client secret too.
check_auth_url = http_client.get(self.get_auth_url()) check_auth_url = http_client.get(self.get_auth_url(app_config, '', '', []))
if check_auth_url.status_code // 100 != 2: if check_auth_url.status_code // 100 != 2:
raise Exception('Got non-200 status code for authorization endpoint') raise Exception('Got non-200 status code for authorization endpoint')

View file

@ -1,5 +1,6 @@
import logging import logging
from oauth.base import OAuthEndpoint
from oauth.login import OAuthLoginService, OAuthLoginException from oauth.login import OAuthLoginService, OAuthLoginException
from util import slash_join from util import slash_join
@ -50,10 +51,13 @@ class GithubOAuthService(OAuthLoginService):
return self._api_endpoint().find('.github.com') < 0 return self._api_endpoint().find('.github.com') < 0
def authorize_endpoint(self): def authorize_endpoint(self):
return slash_join(self._endpoint(), '/login/oauth/authorize') + '?' return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/authorize'))
def token_endpoint(self): def token_endpoint(self):
return slash_join(self._endpoint(), '/login/oauth/access_token') return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/access_token'))
def user_endpoint(self):
return OAuthEndpoint(slash_join(self._api_endpoint(), 'user'))
def _api_endpoint(self): def _api_endpoint(self):
return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/')) return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/'))
@ -65,9 +69,6 @@ class GithubOAuthService(OAuthLoginService):
return endpoint return endpoint
def user_endpoint(self):
return slash_join(self._api_endpoint(), 'user')
def email_endpoint(self): def email_endpoint(self):
return slash_join(self._api_endpoint(), 'user/emails') return slash_join(self._api_endpoint(), 'user/emails')
@ -112,7 +113,7 @@ class GithubOAuthService(OAuthLoginService):
def get_public_config(self): def get_public_config(self):
return { return {
'CLIENT_ID': self.client_id(), 'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix(),
'GITHUB_ENDPOINT': self._endpoint(), 'GITHUB_ENDPOINT': self._endpoint(),
'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False) 'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
} }

View file

@ -1,4 +1,4 @@
from oauth.base import OAuthService from oauth.base import OAuthService, OAuthEndpoint
from util import slash_join from util import slash_join
class GitLabOAuthService(OAuthService): class GitLabOAuthService(OAuthService):
@ -24,17 +24,17 @@ class GitLabOAuthService(OAuthService):
return slash_join(self._endpoint(), suffix) return slash_join(self._endpoint(), suffix)
def authorize_endpoint(self): def authorize_endpoint(self):
return slash_join(self._endpoint(), '/oauth/authorize') return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/authorize'))
def token_endpoint(self): def token_endpoint(self):
return slash_join(self._endpoint(), '/oauth/token') return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/token'))
def validate_client_id_and_secret(self, http_client, app_config): def validate_client_id_and_secret(self, http_client, app_config):
# We validate the client ID and secret by hitting the OAuth token exchange endpoint with # We validate the client ID and secret by hitting the OAuth token exchange endpoint with
# the real client ID and secret, but a fake auth code to exchange. Gitlab's implementation will # the real client ID and secret, but a fake auth code to exchange. Gitlab's implementation will
# return `invalid_client` as the `error` if the client ID or secret is invalid; otherwise, it # return `invalid_client` as the `error` if the client ID or secret is invalid; otherwise, it
# will return another error. # will return another error.
url = self.token_endpoint() url = self.token_endpoint().to_url()
redirect_uri = self.get_redirect_uri(app_config, redirect_suffix='trigger') redirect_uri = self.get_redirect_uri(app_config, redirect_suffix='trigger')
data = { data = {
'code': 'fakecode', 'code': 'fakecode',
@ -55,6 +55,6 @@ class GitLabOAuthService(OAuthService):
def get_public_config(self): def get_public_config(self):
return { return {
'CLIENT_ID': self.client_id(), 'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(), 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix(),
'GITLAB_ENDPOINT': self._endpoint(), 'GITLAB_ENDPOINT': self._endpoint(),
} }

View file

@ -1,3 +1,4 @@
from oauth.base import OAuthEndpoint
from oauth.login import OAuthLoginService from oauth.login import OAuthLoginService
def _get_email_username(email_address): def _get_email_username(email_address):
@ -28,13 +29,14 @@ class GoogleOAuthService(OAuthLoginService):
return ['openid', 'email'] return ['openid', 'email']
def authorize_endpoint(self): def authorize_endpoint(self):
return 'https://accounts.google.com/o/oauth2/auth?response_type=code&' return OAuthEndpoint('https://accounts.google.com/o/oauth2/auth',
params=dict(response_type='code'))
def token_endpoint(self): def token_endpoint(self):
return 'https://accounts.google.com/o/oauth2/token' return OAuthEndpoint('https://accounts.google.com/o/oauth2/token')
def user_endpoint(self): def user_endpoint(self):
return 'https://www.googleapis.com/oauth2/v1/userinfo' return OAuthEndpoint('https://www.googleapis.com/oauth2/v1/userinfo')
def requires_form_encoding(self): def requires_form_encoding(self):
return True return True
@ -59,7 +61,7 @@ class GoogleOAuthService(OAuthLoginService):
def get_public_config(self): def get_public_config(self):
return { return {
'CLIENT_ID': self.client_id(), 'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint() 'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url_prefix()
} }
def get_login_service_id(self, user_info): def get_login_service_id(self, user_info):

View file

@ -0,0 +1,39 @@
import pytest
from oauth.services.github import GithubOAuthService
@pytest.mark.parametrize('trigger_config, domain, api_endpoint, is_enterprise', [
({
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
'API_ENDPOINT': 'https://api.github.com/v3',
}, 'https://github.com', 'https://api.github.com/v3', False),
({
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}, 'https://github.somedomain.com', 'https://github.somedomain.com/api/v3', True),
({
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'API_ENDPOINT': 'http://somedomain.com/api/',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}, 'https://github.somedomain.com', 'http://somedomain.com/api', True),
])
def test_basic_enterprise_config(trigger_config, domain, api_endpoint, is_enterprise):
config = {
'GITHUB_TRIGGER_CONFIG': trigger_config
}
github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG')
assert github_trigger.is_enterprise() == is_enterprise
assert github_trigger.authorize_endpoint().to_url() == '%s/login/oauth/authorize' % domain
assert github_trigger.authorize_endpoint().to_url_prefix() == '%s/login/oauth/authorize?' % domain
assert github_trigger.token_endpoint().to_url() == '%s/login/oauth/access_token' % domain
assert github_trigger.api_endpoint() == api_endpoint
assert github_trigger.user_endpoint().to_url() == '%s/user' % api_endpoint
assert github_trigger.email_endpoint() == '%s/user/emails' % api_endpoint
assert github_trigger.orgs_endpoint() == '%s/user/orgs' % api_endpoint

View file

@ -154,6 +154,16 @@ def discovery_handler(discovery_content):
return handler return handler
@pytest.fixture()
def authorize_handler(discovery_content):
@urlmatch(netloc=r'fakeoidc', path=r'/authorize')
def handler(_, request):
parsed = urlparse.urlparse(request.url)
params = urlparse.parse_qs(parsed.query)
return json.dumps({'authorized': True, 'scope': params['scope'][0], 'state': params['state'][0]})
return handler
@pytest.fixture() @pytest.fixture()
def token_handler(oidc_service, id_token, valid_code): def token_handler(oidc_service, id_token, valid_code):
@urlmatch(netloc=r'fakeoidc', path=r'/token') @urlmatch(netloc=r'fakeoidc', path=r'/token')
@ -237,16 +247,19 @@ def test_basic_config(oidc_service):
def test_discovery(oidc_service, http_client, discovery_content, discovery_handler): def test_discovery(oidc_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler): with HTTMock(discovery_handler):
auth = discovery_content['authorization_endpoint'] + '?response_type=code' auth = discovery_content['authorization_endpoint'] + '?response_type=code'
assert oidc_service.authorize_endpoint() == auth assert oidc_service.authorize_endpoint().to_url() == auth
assert oidc_service.token_endpoint().to_url() == discovery_content['token_endpoint']
if discovery_content['userinfo_endpoint'] is None:
assert oidc_service.user_endpoint() is None
else:
assert oidc_service.user_endpoint().to_url() == discovery_content['userinfo_endpoint']
assert oidc_service.token_endpoint() == discovery_content['token_endpoint']
assert oidc_service.user_endpoint() == discovery_content['userinfo_endpoint']
assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported']) assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported'])
def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler): def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler): with HTTMock(discovery_handler):
auth = discovery_content['authorization_endpoint'] + '?response_type=code&some=param' assert 'some=param' in oidc_withparams_service.authorize_endpoint().to_url()
assert 'some=param' in oidc_withparams_service.authorize_endpoint()
def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler): def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler): with HTTMock(discovery_handler):
@ -260,6 +273,17 @@ def test_public_config(oidc_service, discovery_handler):
assert 'CLIENT_SECRET' not in oidc_service.get_public_config() assert 'CLIENT_SECRET' not in oidc_service.get_public_config()
assert 'bar' not in oidc_service.get_public_config().values() assert 'bar' not in oidc_service.get_public_config().values()
def test_auth_url(oidc_service, discovery_handler, http_client, authorize_handler):
config = {'PREFERRED_URL_SCHEME': 'https', 'SERVER_HOSTNAME': 'someserver'}
with HTTMock(discovery_handler, authorize_handler):
auth_url = oidc_service.get_auth_url(config, '', 'some csrf token', ['one', 'two'])
# Hit the URL and ensure it works.
result = http_client.get(auth_url).json()
assert result['state'] == 'some csrf token'
assert result['scope'] == 'one two'
def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client, def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client,
token_handler): token_handler):
with HTTMock(token_handler, discovery_handler): with HTTMock(token_handler, discovery_handler):

View file

@ -1,48 +0,0 @@
import unittest
from oauth.services.github import GithubOAuthService
class TestGithub(unittest.TestCase):
def test_basic_enterprise_config(self):
config = {
'GITHUB_TRIGGER_CONFIG': {
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}
}
github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG')
self.assertTrue(github_trigger.is_enterprise())
self.assertEquals('https://github.somedomain.com/login/oauth/authorize?', github_trigger.authorize_endpoint())
self.assertEquals('https://github.somedomain.com/login/oauth/access_token', github_trigger.token_endpoint())
self.assertEquals('https://github.somedomain.com/api/v3', github_trigger.api_endpoint())
self.assertEquals('https://github.somedomain.com/api/v3/user', github_trigger.user_endpoint())
self.assertEquals('https://github.somedomain.com/api/v3/user/emails', github_trigger.email_endpoint())
self.assertEquals('https://github.somedomain.com/api/v3/user/orgs', github_trigger.orgs_endpoint())
def test_custom_enterprise_config(self):
config = {
'GITHUB_TRIGGER_CONFIG': {
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'API_ENDPOINT': 'http://somedomain.com/api',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}
}
github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG')
self.assertTrue(github_trigger.is_enterprise())
self.assertEquals('https://github.somedomain.com/login/oauth/authorize?', github_trigger.authorize_endpoint())
self.assertEquals('https://github.somedomain.com/login/oauth/access_token', github_trigger.token_endpoint())
self.assertEquals('http://somedomain.com/api', github_trigger.api_endpoint())
self.assertEquals('http://somedomain.com/api/user', github_trigger.user_endpoint())
self.assertEquals('http://somedomain.com/api/user/emails', github_trigger.email_endpoint())
self.assertEquals('http://somedomain.com/api/user/orgs', github_trigger.orgs_endpoint())
if __name__ == '__main__':
unittest.main()