import urlparse
import json
import logging
import time

from cachetools import TTLCache
from jwkest.jwk import KEYS

from util import slash_join

logger = logging.getLogger(__name__)


class OAuthConfig(object):
    """ Base class for the configuration of a single OAuth service.

        The service-specific settings are read from `config[key_name]`; a missing or empty
        entry is treated as an empty dictionary.
    """

    def __init__(self, config, key_name):
        self.key_name = key_name
        self.config = config.get(key_name) or {}

    def service_name(self):
        """ Returns the display name of the service. Must be overridden. """
        raise NotImplementedError

    def token_endpoint(self):
        """ Returns the URL used to exchange a code for a token. Must be overridden. """
        raise NotImplementedError

    def user_endpoint(self):
        """ Returns the URL used to look up the current user. Must be overridden. """
        raise NotImplementedError

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Validates the configured client ID and secret. Must be overridden. """
        raise NotImplementedError

    def client_id(self):
        """ Returns the configured CLIENT_ID, or None if not set. """
        return self.config.get('CLIENT_ID')

    def client_secret(self):
        """ Returns the configured CLIENT_SECRET, or None if not set. """
        return self.config.get('CLIENT_SECRET')

    def get_redirect_uri(self, app_config, redirect_suffix=''):
        """ Builds the OAuth callback URL for this service from the application's preferred URL
            scheme and server hostname, followed by the given suffix.
        """
        return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
                                                 app_config['SERVER_HOSTNAME'],
                                                 self.service_name().lower(),
                                                 redirect_suffix)

    def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
                                redirect_suffix='', client_auth=False):
        """ Exchanges the given authorization code for an access token by POSTing to the token
            endpoint.

            If `client_auth` is true, the client ID and secret are passed via the HTTP client's
            `auth` argument; otherwise they are placed in the payload. If `form_encode` is true,
            the payload is sent as the request body (`data`); otherwise as query parameters
            (`params`).

            Returns the `access_token` field of the JSON response, or an empty string if the
            response is empty or has no such field.
        """
        payload = {
            'code': code,
            'grant_type': 'authorization_code',
            'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
        }

        headers = {
            'Accept': 'application/json'
        }

        auth = None
        if client_auth:
            auth = (self.client_id(), self.client_secret())
        else:
            payload['client_id'] = self.client_id()
            payload['client_secret'] = self.client_secret()

        token_url = self.token_endpoint()
        if form_encode:
            get_access_token = http_client.post(token_url, data=payload, headers=headers,
                                                auth=auth)
        else:
            get_access_token = http_client.post(token_url, params=payload, headers=headers,
                                                auth=auth)

        json_data = get_access_token.json()
        if not json_data:
            return ''

        return json_data.get('access_token', '')


class GithubOAuthConfig(OAuthConfig):
    """ OAuth configuration for GitHub and GitHub Enterprise. """

    def __init__(self, config, key_name):
        super(GithubOAuthConfig, self).__init__(config, key_name)

    def service_name(self):
        return 'GitHub'

    def allowed_organizations(self):
        """ Returns the lowercased list of allowed organizations, or None if organization
            restriction is disabled or no list is configured.
        """
        if not self.config.get('ORG_RESTRICT', False):
            return None

        allowed = self.config.get('ALLOWED_ORGANIZATIONS', None)
        if allowed is None:
            return None

        return [org.lower() for org in allowed]

    def get_public_url(self, suffix):
        return slash_join(self._endpoint(), suffix)

    def _endpoint(self):
        return self.config.get('GITHUB_ENDPOINT', 'https://github.com')

    def is_enterprise(self):
        """ Returns whether the configured endpoint is a GitHub Enterprise installation, i.e.
            is not github.com or one of its subdomains.
        """
        endpoint = self._endpoint()
        if endpoint.find('.github.com') >= 0:
            return False

        # The bare host `github.com` (the default endpoint) has no leading dot and so is missed
        # by the substring check above; compare the parsed hostname explicitly.
        return urlparse.urlparse(endpoint).hostname != 'github.com'

    def authorize_endpoint(self):
        return slash_join(self._endpoint(), '/login/oauth/authorize') + '?'

    def token_endpoint(self):
        return slash_join(self._endpoint(), '/login/oauth/access_token')

    def _api_endpoint(self):
        return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/'))

    def api_endpoint(self):
        """ Returns the API endpoint with a single trailing slash (if any) removed. """
        endpoint = self._api_endpoint()
        if endpoint.endswith('/'):
            return endpoint[0:-1]

        return endpoint

    def user_endpoint(self):
        return slash_join(self._api_endpoint(), 'user')

    def email_endpoint(self):
        return slash_join(self._api_endpoint(), 'user/emails')

    def orgs_endpoint(self):
        return slash_join(self._api_endpoint(), 'user/orgs')

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Returns whether the configured client ID and secret are valid for the configured
            GitHub API endpoint. Raises an Exception if the endpoint does not look like GitHub.
        """
        # First: Verify that the github endpoint is actually Github by checking for the
        # X-GitHub-Request-Id here.
        api_endpoint = self._api_endpoint()
        result = http_client.get(api_endpoint, auth=(self.client_id(), self.client_secret()),
                                 timeout=5)
        if 'X-GitHub-Request-Id' not in result.headers:
            raise Exception('Endpoint is not a Github (Enterprise) installation')

        # Next: Verify the client ID and secret.
        # Note: The following code is a hack until such time as Github officially adds an API
        # endpoint for verifying a {client_id, client_secret} pair. That being said, this hack
        # was given to us *by a Github Engineer*, so I think it is okay for the time being :)
        #
        # TODO(jschorr): Replace with the real API call once added.
        #
        # Hitting the endpoint applications/{client_id}/tokens/foo will result in the following
        # behavior IF the client_id is given as the HTTP username and the client_secret as the
        # HTTP password:
        #   - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error.
        #   - If the pair is valid, then we get a 404 because the 'foo' token does not exist.
        validate_endpoint = slash_join(api_endpoint,
                                       'applications/%s/tokens/foo' % self.client_id())
        result = http_client.get(validate_endpoint,
                                 auth=(self.client_id(), self.client_secret()),
                                 timeout=5)
        return result.status_code == 404

    def validate_organization(self, organization_id, http_client):
        """ Returns whether a GET of the (lowercased) organization on the API endpoint
            returns HTTP 200.
        """
        org_endpoint = slash_join(self._api_endpoint(), 'orgs/%s' % organization_id.lower())

        result = http_client.get(org_endpoint,
                                 headers={'Accept': 'application/vnd.github.moondragon+json'},
                                 timeout=5)

        return result.status_code == 200

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
            'GITHUB_ENDPOINT': self._endpoint(),
            'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
        }


class GoogleOAuthConfig(OAuthConfig):
    """ OAuth configuration for Google. """

    def __init__(self, config, key_name):
        super(GoogleOAuthConfig, self).__init__(config, key_name)

    def service_name(self):
        return 'Google'

    def authorize_endpoint(self):
        return 'https://accounts.google.com/o/oauth2/auth?response_type=code&'

    def token_endpoint(self):
        return 'https://accounts.google.com/o/oauth2/token'

    def user_endpoint(self):
        return 'https://www.googleapis.com/oauth2/v1/userinfo'

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Returns False only if the token endpoint answers the probe request with HTTP 401. """
        # To verify the Google client ID and secret, we hit the
        # https://www.googleapis.com/oauth2/v3/token endpoint with an invalid request. If the
        # client ID or secret are invalid, we get returned a 401 Unauthorized. Otherwise, we get
        # returned another response code.
        url = 'https://www.googleapis.com/oauth2/v3/token'
        data = {
            'code': 'fakecode',
            'client_id': self.client_id(),
            'client_secret': self.client_secret(),
            'grant_type': 'authorization_code',
            'redirect_uri': 'http://example.com'
        }

        result = http_client.post(url, data=data, timeout=5)
        return result.status_code != 401

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
        }


class GitLabOAuthConfig(OAuthConfig):
    """ OAuth configuration for GitLab. """

    def __init__(self, config, key_name):
        super(GitLabOAuthConfig, self).__init__(config, key_name)

    def _endpoint(self):
        return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')

    def api_endpoint(self):
        return self._endpoint()

    def get_public_url(self, suffix):
        return slash_join(self._endpoint(), suffix)

    def service_name(self):
        return 'GitLab'

    def authorize_endpoint(self):
        return slash_join(self._endpoint(), '/oauth/authorize')

    def token_endpoint(self):
        return slash_join(self._endpoint(), '/oauth/token')

    def validate_client_id_and_secret(self, http_client, app_config):
        """ Returns False if the token endpoint returns an empty JSON body or reports the
            `invalid_client` error for a probe request; True otherwise.
        """
        url = self.token_endpoint()
        redirect_uri = self.get_redirect_uri(app_config, redirect_suffix='trigger')
        data = {
            'code': 'fakecode',
            'client_id': self.client_id(),
            'client_secret': self.client_secret(),
            'grant_type': 'authorization_code',
            'redirect_uri': redirect_uri
        }

        # We validate by checking the error code we receive from this call.
        result = http_client.post(url, data=data, timeout=5)
        value = result.json()
        if not value:
            return False

        return value.get('error', '') != 'invalid_client'

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
            'GITLAB_ENDPOINT': self._endpoint(),
        }


OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600  # 1 hour


class OIDCConfig(OAuthConfig):
    """ OAuth configuration for an OpenID Connect provider, optionally loaded via discovery. """

    def __init__(self, config, key_name):
        super(OIDCConfig, self).__init__(config, key_name)

        # Single-entry cache: on a miss, the cache invokes _get_public_key with the missing key.
        self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
        self._oidc_config = {}
        self._http_client = config['HTTPCLIENT']

        if self.config.get('OIDC_SERVER'):
            self._load_via_discovery(config.get('DEBUGGING', False))

    def _load_via_discovery(self, is_debugging):
        """ Loads the OIDC configuration from the server's well-known discovery document.

            Raises an Exception if the server is not accessed over https (unless debugging), if
            the discovery request does not return a 2xx status, or if the response body is not
            valid JSON.
        """
        oidc_server = self.config['OIDC_SERVER']
        if not oidc_server.startswith('https://') and not is_debugging:
            raise Exception('OIDC server must be accessed over SSL')

        discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
        discovery = self._http_client.get(discovery_url, timeout=5)

        # Floor division so that every 2xx status matches regardless of division semantics.
        if discovery.status_code // 100 != 2:
            raise Exception("Could not load OIDC discovery information")

        try:
            self._oidc_config = json.loads(discovery.text)
        except ValueError:
            logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
            raise Exception("Could not parse OIDC discovery information")

    def authorize_endpoint(self):
        return self._oidc_config.get('authorization_endpoint', '') + '?'

    def token_endpoint(self):
        return self._oidc_config.get('token_endpoint')

    def user_endpoint(self):
        return None

    def validate_client_id_and_secret(self, http_client, app_config):
        pass

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
        }

    @property
    def issuer(self):
        """ Returns the configured OIDC_ISSUER if present, otherwise the OIDC_SERVER. """
        # Only fall back to (and require) OIDC_SERVER when no issuer is configured.
        if 'OIDC_ISSUER' in self.config:
            return self.config['OIDC_ISSUER']

        return self.config['OIDC_SERVER']

    def get_public_key(self, force_refresh=False):
        """ Retrieves the public key for this handler. """
        # If force_refresh is true, we expire all the items in the cache by setting the time to
        # the current time + the expiration TTL.
        if force_refresh:
            self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

        # Retrieve the public key from the cache. If the cache does not contain the public key,
        # it will internally call _get_public_key to retrieve it and then save it. The None is
        # an arbitrary key chosen to be stored in the cache, and could be anything.
        return self._public_key_cache[None]

    def _get_public_key(self, _key=None):
        """ Retrieves the public key for this handler.

            Loads the key set from the `jwks_uri` of the OIDC configuration and returns the first
            key exported as PEM. Raises an Exception if the provider returns no keys.

            `_key` is the cache key handed over by the cache's `missing` callback; it is unused.
        """
        keys_url = self._oidc_config['jwks_uri']

        keys = KEYS()
        keys.load_from_url(keys_url)

        key_list = list(keys)
        if not key_list:
            raise Exception('No keys provided by OIDC provider')

        rsa_key = key_list[0]
        rsa_key.deserialize()
        return rsa_key.key.exportKey('PEM')


class DexOAuthConfig(OIDCConfig):
    """ OIDC configuration for Dex. """

    def service_name(self):
        return 'Dex'

    @property
    def public_title(self):
        return self.get_public_config()['OIDC_TITLE']

    def get_public_config(self):
        return {
            'CLIENT_ID': self.client_id(),
            'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),

            # TODO(jschorr): This should ideally come from the Dex side.
            'OIDC_TITLE': 'Dex',
            'OIDC_LOGO': 'https://tectonic.com/assets/ico/favicon-96x96.png'
        }