diff --git a/util/config/oauth.py b/util/config/oauth.py index a58a8f69d..35b4b02fb 100644 --- a/util/config/oauth.py +++ b/util/config/oauth.py @@ -5,6 +5,7 @@ import time from cachetools import TTLCache from jwkest.jwk import KEYS +from util.string import slash_join logger = logging.getLogger(__name__) @@ -31,12 +32,6 @@ class OAuthConfig(object): def client_secret(self): return self.config.get('CLIENT_SECRET') - def _get_url(self, endpoint, *args): - for arg in args: - endpoint = urlparse.urljoin(endpoint, arg) - - return endpoint - def get_redirect_uri(self, app_config, redirect_suffix=''): return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'], app_config['SERVER_HOSTNAME'], @@ -95,40 +90,34 @@ class GithubOAuthConfig(OAuthConfig): return [org.lower() for org in allowed] def get_public_url(self, suffix): - return urlparse.urljoin(self._endpoint(), suffix) + return slash_join(self._endpoint(), suffix) def _endpoint(self): - endpoint = self.config.get('GITHUB_ENDPOINT', 'https://github.com') - if not endpoint.endswith('/'): - endpoint = endpoint + '/' - return endpoint + return self.config.get('GITHUB_ENDPOINT', 'https://github.com') def is_enterprise(self): return self._endpoint().find('.github.com') < 0 def authorize_endpoint(self): - return self._get_url(self._endpoint(), '/login/oauth/authorize') + '?' + return slash_join(self._endpoint(), '/login/oauth/authorize') + '?' def token_endpoint(self): - return self._get_url(self._endpoint(), '/login/oauth/access_token') + return slash_join(self._endpoint(), '/login/oauth/access_token') def _api_endpoint(self): - return self.config.get('API_ENDPOINT', self._get_url(self._endpoint(), '/api/v3/')) + return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/')) def api_endpoint(self): return self._api_endpoint()[0:-1] def user_endpoint(self): - api_endpoint = self._api_endpoint() - return self._get_url(api_endpoint, 'user') + return slash_join(self._api_endpoint(), 'user') def email_endpoint(self): - api_endpoint = self._api_endpoint() - return self._get_url(api_endpoint, 'user/emails') + return slash_join(self._api_endpoint(), 'user/emails') def orgs_endpoint(self): - api_endpoint = self._api_endpoint() - return self._get_url(api_endpoint, 'user/orgs') + return slash_join(self._api_endpoint(), 'user/orgs') def validate_client_id_and_secret(self, http_client, app_config): # First: Verify that the github endpoint is actually Github by checking for the @@ -150,14 +139,13 @@ class GithubOAuthConfig(OAuthConfig): # password: # - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error. # - If the pair is valid, then we get a 404 because the 'foo' token does not exists. - validate_endpoint = self._get_url(api_endpoint, 'applications/%s/tokens/foo' % self.client_id()) + validate_endpoint = slash_join(api_endpoint, 'applications/%s/tokens/foo' % self.client_id()) result = http_client.get(validate_endpoint, auth=(self.client_id(), self.client_secret()), timeout=5) return result.status_code == 404 def validate_organization(self, organization_id, http_client): - api_endpoint = self._api_endpoint() - org_endpoint = self._get_url(api_endpoint, 'orgs/%s' % organization_id.lower()) + org_endpoint = slash_join(self._api_endpoint(), 'orgs/%s' % organization_id.lower()) result = http_client.get(org_endpoint, headers={'Accept': 'application/vnd.github.moondragon+json'}, @@ -221,25 +209,22 @@ class GitLabOAuthConfig(OAuthConfig): super(GitLabOAuthConfig, self).__init__(config, key_name) def _endpoint(self): - endpoint = self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com') - if not endpoint.endswith('/'): - endpoint = endpoint + '/' - return endpoint + return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com') def api_endpoint(self): return self._endpoint() def get_public_url(self, suffix): - return urlparse.urljoin(self._endpoint(), suffix) + return slash_join(self._endpoint(), suffix) def service_name(self): return 'GitLab' def authorize_endpoint(self): - return self._get_url(self._endpoint(), '/oauth/authorize') + return slash_join(self._endpoint(), '/oauth/authorize') def token_endpoint(self): - return self._get_url(self._endpoint(), '/oauth/token') + return slash_join(self._endpoint(), '/oauth/token') def validate_client_id_and_secret(self, http_client, app_config): url = self.token_endpoint()