362 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			362 lines
		
	
	
	
		
			12 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import urlparse
 | |
| import github
 | |
| import json
 | |
| import logging
 | |
| import time
 | |
| 
 | |
| from cachetools.func import TTLCache
 | |
| from jwkest.jwk import KEYS, keyrep
 | |
| 
 | |
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
 | |
| 
 | |
class OAuthConfig(object):
  """ Base configuration for a single OAuth service.

      The service's own settings are read from `config[key_name]`; a missing or falsy entry is
      treated as an empty dict. Subclasses provide the service name, endpoints and credential
      validation.
  """
  def __init__(self, config, key_name):
    self.key_name = key_name
    self.config = config.get(key_name) or {}

  def service_name(self):
    """ Returns the display name of the service; lowercased, it also forms the redirect URI path. """
    raise NotImplementedError

  def token_endpoint(self):
    """ Returns the URL to which authorization codes are POSTed in exchange for a token. """
    raise NotImplementedError

  def user_endpoint(self):
    """ Returns the URL of the service's user-information endpoint. """
    raise NotImplementedError

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Checks the configured client ID and secret against the service.

        Subclasses return a truthy value when the credentials appear valid.
    """
    raise NotImplementedError

  def client_id(self):
    """ Returns the configured CLIENT_ID, or None if unset. """
    return self.config.get('CLIENT_ID')

  def client_secret(self):
    """ Returns the configured CLIENT_SECRET, or None if unset. """
    return self.config.get('CLIENT_SECRET')

  def _get_url(self, endpoint, *args):
    """ Successively resolves each of `args` against `endpoint` using urljoin.

        Standard urljoin rules apply: a segment starting with '/' replaces the entire path of the
        base, and a base whose path does not end in '/' has its last segment replaced.
    """
    for arg in args:
      endpoint = urlparse.urljoin(endpoint, arg)

    return endpoint

  def get_redirect_uri(self, app_config, redirect_suffix=''):
    """ Builds <scheme>://<hostname>/oauth2/<service>/callback<redirect_suffix>.

        `redirect_suffix` is appended verbatim, so it should include its own leading '/'.
    """
    return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
                                             app_config['SERVER_HOSTNAME'],
                                             self.service_name().lower(),
                                             redirect_suffix)

  def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
                              redirect_suffix='', client_auth=False):
    """ Exchanges an OAuth authorization `code` for an access token.

        Args:
          app_config: application config used to build the redirect URI.
          http_client: requests-like client exposing `post(url, ...)`.
          code: the authorization code returned by the service.
          form_encode: send the payload as a form body (`data=`) instead of query params.
          redirect_suffix: appended to the callback redirect URI.
          client_auth: send the client ID/secret as HTTP basic auth instead of in the payload.

        Returns:
          The `access_token` value from the JSON response, or '' when the response is empty,
          not valid JSON, not a JSON object, or has no access_token.
    """
    payload = {
      'code': code,
      'grant_type': 'authorization_code',
      'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
    }

    headers = {
      'Accept': 'application/json'
    }

    auth = None
    if client_auth:
      auth = (self.client_id(), self.client_secret())
    else:
      payload['client_id'] = self.client_id()
      payload['client_secret'] = self.client_secret()

    token_url = self.token_endpoint()
    if form_encode:
      response = http_client.post(token_url, data=payload, headers=headers, auth=auth)
    else:
      response = http_client.post(token_url, params=payload, headers=headers, auth=auth)

    try:
      json_data = response.json()
    except ValueError:
      # An unparseable body (e.g. an HTML error page) is treated like an empty response:
      # no token, rather than an unhandled exception in the caller.
      return ''

    # Empty bodies and non-object JSON (lists, strings, null) carry no token.
    if not isinstance(json_data, dict):
      return ''

    return json_data.get('access_token', '')
 | |
| 
 | |
| 
 | |
class GithubOAuthConfig(OAuthConfig):
  """ OAuth configuration for GitHub or a GitHub Enterprise installation. """

  def __init__(self, config, key_name):
    super(GithubOAuthConfig, self).__init__(config, key_name)

  def service_name(self):
    return 'GitHub'

  def allowed_organizations(self):
    """ Returns the lowercased ALLOWED_ORGANIZATIONS list, or None when ORG_RESTRICT is off or no
        organizations are configured. """
    if not self.config.get('ORG_RESTRICT', False):
      return None

    allowed = self.config.get('ALLOWED_ORGANIZATIONS', None)
    if allowed is None:
      return None

    return [org.lower() for org in allowed]

  def get_public_url(self, suffix):
    """ Returns `suffix` appended verbatim to the slash-terminated GitHub web endpoint. """
    return '%s%s' % (self._endpoint(), suffix)

  def _endpoint(self):
    """ Returns GITHUB_ENDPOINT (default https://github.com), always ending in '/'. """
    endpoint = self.config.get('GITHUB_ENDPOINT', 'https://github.com')
    if not endpoint.endswith('/'):
      endpoint = endpoint + '/'
    return endpoint

  def is_enterprise(self):
    """ Returns True unless the web endpoint's host is github.com or a subdomain of it. """
    # Compare the parsed hostname instead of searching the URL for '.github.com'. That
    # substring test classified the default https://github.com (no leading dot) as Enterprise,
    # and treated hosts such as ghe.github.com.example.org as public GitHub.
    hostname = urlparse.urlparse(self._endpoint()).hostname or ''
    return not (hostname == 'github.com' or hostname.endswith('.github.com'))

  def authorize_endpoint(self):
    """ Returns the authorize URL with a trailing '?', ready for query parameters. """
    return self._get_url(self._endpoint(), '/login/oauth/authorize') + '?'

  def token_endpoint(self):
    return self._get_url(self._endpoint(), '/login/oauth/access_token')

  def _api_endpoint(self):
    """ Returns API_ENDPOINT (default <endpoint>/api/v3/), always ending in '/'. """
    endpoint = self.config.get('API_ENDPOINT', self._get_url(self._endpoint(), '/api/v3/'))
    # Normalize like _endpoint(): without the slash, urljoin replaces the last path segment
    # (".../api/v3" + "user" -> ".../api/user") and api_endpoint() would chop a real character.
    if not endpoint.endswith('/'):
      endpoint = endpoint + '/'
    return endpoint

  def api_endpoint(self):
    """ Returns the API root without its trailing slash. """
    return self._api_endpoint()[0:-1]

  def user_endpoint(self):
    api_endpoint = self._api_endpoint()
    return self._get_url(api_endpoint, 'user')

  def email_endpoint(self):
    api_endpoint = self._api_endpoint()
    return self._get_url(api_endpoint, 'user/emails')

  def orgs_endpoint(self):
    api_endpoint = self._api_endpoint()
    return self._get_url(api_endpoint, 'user/orgs')

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns True if the endpoint answers the client-credential probe with a 404.

        Raises:
          Exception: if the API root response lacks the X-GitHub-Request-Id header.
    """
    # First: Verify that the github endpoint is actually Github by checking for the
    # X-GitHub-Request-Id here.
    api_endpoint = self._api_endpoint()
    result = http_client.get(api_endpoint, auth=(self.client_id(), self.client_secret()), timeout=5)
    if 'X-GitHub-Request-Id' not in result.headers:
      raise Exception('Endpoint is not a Github (Enterprise) installation')

    # Next: Verify the client ID and secret.
    # Note: The following code is a hack until such time as Github officially adds an API endpoint
    # for verifying a {client_id, client_secret} pair. That being said, this hack was given to us
    # *by a Github Engineer*, so I think it is okay for the time being :)
    #
    # TODO(jschorr): Replace with the real API call once added.
    #
    # Hitting the endpoint applications/{client_id}/tokens/foo will result in the following
    # behavior IF the client_id is given as the HTTP username and the client_secret as the HTTP
    # password:
    #   - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error.
    #   - If the pair is valid, then we get a 404 because the 'foo' token does not exist.
    validate_endpoint = self._get_url(api_endpoint, 'applications/%s/tokens/foo' % self.client_id())
    result = http_client.get(validate_endpoint, auth=(self.client_id(), self.client_secret()),
                             timeout=5)
    return result.status_code == 404

  def validate_organization(self, organization_id, http_client):
    """ Returns True if GET orgs/<organization_id lowercased> on the API answers 200. """
    api_endpoint = self._api_endpoint()
    org_endpoint = self._get_url(api_endpoint, 'orgs/%s' % organization_id.lower())

    result = http_client.get(org_endpoint,
                             headers={'Accept': 'application/vnd.github.moondragon+json'},
                             timeout=5)

    return result.status_code == 200

  def get_public_config(self):
    """ Returns the settings that are exposed publicly for this service. """
    return {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
      'GITHUB_ENDPOINT': self._endpoint(),
      'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
    }
 | |
| 
 | |
| 
 | |
| 
 | |
class GoogleOAuthConfig(OAuthConfig):
  """ OAuth configuration for Google sign-in. """

  def __init__(self, config, key_name):
    super(GoogleOAuthConfig, self).__init__(config, key_name)

  def service_name(self):
    return 'Google'

  def authorize_endpoint(self):
    # Already carries response_type=code and ends in '&', presumably so callers can append the
    # remaining query parameters directly.
    return 'https://accounts.google.com/o/oauth2/auth?response_type=code&'

  def token_endpoint(self):
    return 'https://accounts.google.com/o/oauth2/token'

  def user_endpoint(self):
    return 'https://www.googleapis.com/oauth2/v1/userinfo'

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns True unless Google rejects the client ID/secret with 401 Unauthorized.

        The check POSTs an authorization-code exchange with a deliberately fake code to
        https://www.googleapis.com/oauth2/v3/token. Any status other than 401 (for example an
        error about the fake code itself) is taken to mean the client credentials were accepted.
    """
    url = 'https://www.googleapis.com/oauth2/v3/token'
    data = {
      'code': 'fakecode',
      'client_id': self.client_id(),
      'client_secret': self.client_secret(),
      'grant_type': 'authorization_code',
      'redirect_uri': 'http://example.com'
    }

    result = http_client.post(url, data=data, timeout=5)
    return result.status_code != 401

  def get_public_config(self):
    """ Returns the settings that are exposed publicly for this service. """
    return {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
    }
 | |
| 
 | |
| 
 | |
class GitLabOAuthConfig(OAuthConfig):
  """ OAuth configuration for GitLab.com or a self-hosted GitLab installation. """

  def __init__(self, config, key_name):
    super(GitLabOAuthConfig, self).__init__(config, key_name)

  def _endpoint(self):
    """ Returns GITLAB_ENDPOINT (default https://gitlab.com), always ending in '/'. """
    endpoint = self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')
    if not endpoint.endswith('/'):
      endpoint = endpoint + '/'
    return endpoint

  def service_name(self):
    return 'GitLab'

  # Paths are relative (no leading '/') so that a GitLab served under a sub-path, such as
  # https://example.com/gitlab/, keeps that prefix. For root endpoints the result is unchanged,
  # because _endpoint() always ends in '/'.
  def authorize_endpoint(self):
    return self._get_url(self._endpoint(), 'oauth/authorize')

  def token_endpoint(self):
    return self._get_url(self._endpoint(), 'oauth/token')

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns False if GitLab reports 'invalid_client' for a fake code exchange, or if the
        response is not a non-empty JSON object; otherwise returns True. """
    url = self.token_endpoint()
    # The suffix is appended verbatim after 'callback', so it needs its own leading '/';
    # 'trigger' alone produced '.../callbacktrigger'.
    redirect_uri = self.get_redirect_uri(app_config, redirect_suffix='/trigger')
    data = {
      'code': 'fakecode',
      'client_id': self.client_id(),
      'client_secret': self.client_secret(),
      'grant_type': 'authorization_code',
      'redirect_uri': redirect_uri
    }

    # We validate by checking the error code we receive from this call.
    result = http_client.post(url, data=data, timeout=5)
    try:
      value = result.json()
    except ValueError:
      # A non-JSON body (e.g. an HTML error page) cannot confirm the credentials.
      return False

    if not value or not isinstance(value, dict):
      return False

    return value.get('error', '') != 'invalid_client'

  def get_public_config(self):
    """ Returns the settings that are exposed publicly for this service. """
    return {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
      'GITLAB_ENDPOINT': self._endpoint(),
    }
 | |
| 
 | |
| 
 | |
# Path of the OpenID Connect discovery document, resolved relative to the OIDC server URL.
OIDC_WELLKNOWN = ".well-known/openid-configuration"

# Seconds a fetched provider public key stays cached before it is re-fetched.
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour

class OIDCConfig(OAuthConfig):
  """ Configuration for an OpenID Connect provider, populated through OIDC discovery.

      The application config must contain HTTPCLIENT. When the service config has OIDC_SERVER,
      discovery runs at construction time, so the constructor can raise.
  """
  def __init__(self, config, key_name):
    super(OIDCConfig, self).__init__(config, key_name)

    # Single-entry cache of the provider's public key. On a miss, cachetools calls the `missing`
    # callback with the missing key.
    self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
    self._oidc_config = {}
    self._http_client = config['HTTPCLIENT']

    if self.config.get('OIDC_SERVER'):
      self._load_via_discovery(config['DEBUGGING'])

  def _load_via_discovery(self, is_debugging):
    """ Fetches and parses the discovery document into self._oidc_config.

        Raises:
          Exception: if the server is not https (unless debugging), the fetch returns a non-2xx
            status, or the body is not valid JSON.
    """
    oidc_server = self.config['OIDC_SERVER']
    if not oidc_server.startswith('https://') and not is_debugging:
      raise Exception('OIDC server must be accessed over SSL')

    # Without a trailing slash, urljoin would replace the last path segment of the server URL
    # (https://host/dex -> https://host/.well-known/...) instead of appending beneath it.
    if not oidc_server.endswith('/'):
      oidc_server = oidc_server + '/'

    discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
    discovery = self._http_client.get(discovery_url, timeout=5)

    # Floor division accepts every 2xx status on both Python 2 and 3. Under true division,
    # 201 / 100 == 2.01 would be rejected.
    if discovery.status_code // 100 != 2:
      raise Exception("Could not load OIDC discovery information")

    try:
      self._oidc_config = json.loads(discovery.text)
    except ValueError:
      logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
      raise Exception("Could not parse OIDC discovery information")

  def authorize_endpoint(self):
    """ Returns the discovered authorization endpoint plus a trailing '?' ('?' alone if unknown). """
    return self._oidc_config.get('authorization_endpoint', '') + '?'

  def token_endpoint(self):
    """ Returns the discovered token endpoint, or None if discovery did not provide one. """
    return self._oidc_config.get('token_endpoint')

  def user_endpoint(self):
    return None

  def validate_client_id_and_secret(self, http_client, app_config):
    # No validation is performed; note that this returns None.
    pass

  def get_public_config(self):
    """ Returns the settings that are exposed publicly for this service. """
    return {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
    }

  @property
  def issuer(self):
    """ Returns OIDC_ISSUER when configured, otherwise OIDC_SERVER. """
    # Only read OIDC_SERVER when it is needed. The previous `.get(..., config['OIDC_SERVER'])`
    # evaluated the fallback eagerly and raised KeyError even when OIDC_ISSUER was set.
    if 'OIDC_ISSUER' in self.config:
      return self.config['OIDC_ISSUER']
    return self.config['OIDC_SERVER']

  def get_public_key(self, force_refresh=False):
    """ Returns the provider's public key in PEM form, fetching and caching it as needed. """
    # If force_refresh is true, we expire all the items in the cache by setting the time to
    # the current time + the expiration TTL.
    if force_refresh:
      self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

    # Retrieve the public key from the cache. If the cache does not contain the public key,
    # it will internally call _get_public_key to retrieve it and then save it. The None is
    # an arbitrary key chosen for the single cache entry, and could be anything.
    return self._public_key_cache[None]

  def _get_public_key(self, _cache_key=None):
    """ Loads the provider's JWKS (from the discovered jwks_uri) and returns its first key as PEM.

        `_cache_key` is ignored. It exists because the cache's `missing` callback is invoked with
        the missing key; without it, every cache miss raised TypeError.

        Raises:
          Exception: if the provider publishes no keys.
    """
    keys_url = self._oidc_config['jwks_uri']

    keys = KEYS()
    keys.load_from_url(keys_url)

    key_list = list(keys)
    if not key_list:
      raise Exception('No keys provided by OIDC provider')

    # NOTE(review): assumes the first published key is an RSA key (exportKey is RSA-specific).
    rsa_key = key_list[0]
    rsa_key.deserialize()
    return rsa_key.key.exportKey('PEM')
 | |
| 
 | |
| 
 | |
class DexOAuthConfig(OIDCConfig):
  """ OIDC configuration for a Dex identity provider. """

  def service_name(self):
    return 'Dex'

  @property
  def public_title(self):
    """ Display title of this provider, taken from its public config. """
    public_config = self.get_public_config()
    return public_config['OIDC_TITLE']

  def get_public_config(self):
    """ Returns the public settings for Dex, including a fixed display title and logo. """
    public_config = {}
    public_config['CLIENT_ID'] = self.client_id()
    public_config['AUTHORIZE_ENDPOINT'] = self.authorize_endpoint()

    # TODO(jschorr): This should ideally come from the Dex side.
    public_config['OIDC_TITLE'] = 'Dex'
    public_config['OIDC_LOGO'] = 'https://tectonic.com/assets/ico/favicon-96x96.png'
    return public_config
 |