initial import for Open Source 🎉

This commit is contained in:
Jimmy Zelinskie 2019-11-12 11:09:47 -05:00
parent 1898c361f3
commit 9c0dd3b722
2048 changed files with 218743 additions and 0 deletions

0
oauth/__init__.py Normal file
View file

193
oauth/base.py Normal file
View file

@ -0,0 +1,193 @@
import copy
import logging
import urllib
import urlparse
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from util.config import URLSchemeAndHostname
logger = logging.getLogger(__name__)
class OAuthEndpoint(object):
def __init__(self, base_url, params=None):
self.base_url = base_url
self.params = params or {}
def with_param(self, name, value):
params_copy = copy.copy(self.params)
params_copy[name] = value
return OAuthEndpoint(self.base_url, params_copy)
def with_params(self, parameters):
params_copy = copy.copy(self.params)
params_copy.update(parameters)
return OAuthEndpoint(self.base_url, params_copy)
def to_url(self):
(scheme, netloc, path, _, fragment) = urlparse.urlsplit(self.base_url)
updated_query = urllib.urlencode(self.params)
return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment))
class OAuthExchangeCodeException(Exception):
""" Exception raised if a code exchange fails. """
pass
class OAuthGetUserInfoException(Exception):
""" Exception raised if a call to get user information fails. """
pass
@add_metaclass(ABCMeta)
class OAuthService(object):
""" A base class for defining an external service, exposed via OAuth. """
def __init__(self, config, key_name):
self.key_name = key_name
self.config = config.get(key_name) or {}
@abstractmethod
def service_id(self):
""" The internal ID for this service. Must match the URL portion for the service, e.g. `github`
"""
pass
@abstractmethod
def service_name(self):
""" The user-readable name for the service, e.g. `GitHub`"""
pass
@abstractmethod
def token_endpoint(self):
""" Returns the endpoint at which the OAuth code can be exchanged for a token. """
pass
@abstractmethod
def user_endpoint(self):
""" Returns the endpoint at which user information can be looked up. """
pass
@abstractmethod
def authorize_endpoint(self):
""" Returns the for authorization of the OAuth service. """
pass
@abstractmethod
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
""" Performs validation of the client ID and secret, raising an exception on failure. """
pass
def requires_form_encoding(self):
""" Returns True if form encoding is necessary for the exchange_code_for_token call. """
return False
def client_id(self):
return self.config.get('CLIENT_ID')
def client_secret(self):
return self.config.get('CLIENT_SECRET')
def login_binding_field(self):
""" Returns the name of the field (`username` or `email`) used for auto binding an external
login service account to an *internal* login service account. For example, if the external
login service is GitHub and the internal login service is LDAP, a value of `email` here
will cause login-with-Github to conduct a search (via email) in LDAP for a user, an auto
bind the external and internal users together. May return None, in which case no binding
is performing, and login with this external account will simply create a new account in the
database.
"""
return self.config.get('LOGIN_BINDING_FIELD', None)
def get_auth_url(self, url_scheme_and_hostname, redirect_suffix, csrf_token, scopes):
""" Retrieves the authorization URL for this login service. """
redirect_uri = '%s/oauth2/%s/callback%s' % (url_scheme_and_hostname.get_url(),
self.service_id(),
redirect_suffix)
params = {
'client_id': self.client_id(),
'redirect_uri': redirect_uri,
'scope': ' '.join(scopes),
'state': csrf_token,
}
return self.authorize_endpoint().with_params(params).to_url()
def get_redirect_uri(self, url_scheme_and_hostname, redirect_suffix=''):
return '%s://%s/oauth2/%s/callback%s' % (url_scheme_and_hostname.url_scheme,
url_scheme_and_hostname.hostname,
self.service_id(),
redirect_suffix)
def get_user_info(self, http_client, token):
token_param = {
'alt': 'json',
}
headers = {
'Authorization': 'Bearer %s' % token,
}
got_user = http_client.get(self.user_endpoint().to_url(), params=token_param, headers=headers)
if got_user.status_code // 100 != 2:
raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' %
got_user.status_code)
user_info = got_user.json()
if user_info is None:
raise OAuthGetUserInfoException()
return user_info
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
redirect_suffix='', client_auth=False):
""" Exchanges an OAuth access code for the associated OAuth token. """
json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
client_auth)
access_token = json_data.get('access_token', None)
if access_token is None:
logger.debug('Got successful get_access_token response with missing token: %s', json_data)
raise OAuthExchangeCodeException('Missing `access_token` in OAuth response')
return access_token
def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
client_auth=False):
""" Exchanges an OAuth access code for associated OAuth token and other data. """
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(app_config)
payload = {
'code': code,
'grant_type': 'authorization_code',
'redirect_uri': self.get_redirect_uri(url_scheme_and_hostname, redirect_suffix)
}
headers = {
'Accept': 'application/json'
}
auth = None
if client_auth:
auth = (self.client_id(), self.client_secret())
else:
payload['client_id'] = self.client_id()
payload['client_secret'] = self.client_secret()
token_url = self.token_endpoint().to_url()
if form_encode:
get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
else:
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
if get_access_token.status_code // 100 != 2:
logger.debug('Got get_access_token response %s', get_access_token.text)
raise OAuthExchangeCodeException('Got non-2XX response for code exchange: %s' %
get_access_token.status_code)
json_data = get_access_token.json()
if not json_data:
raise OAuthExchangeCodeException('Got non-JSON response for code exchange')
if 'error' in json_data:
raise OAuthExchangeCodeException(json_data.get('error_description', json_data['error']))
return json_data

96
oauth/login.py Normal file
View file

@ -0,0 +1,96 @@
import logging
from abc import ABCMeta, abstractmethod
from six import add_metaclass
import features
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
logger = logging.getLogger(__name__)
class OAuthLoginException(Exception):
""" Exception raised if a login operation fails. """
pass
@add_metaclass(ABCMeta)
class OAuthLoginService(OAuthService):
""" A base class for defining an OAuth-compliant service that can be used for, amongst other
things, login and authentication. """
@abstractmethod
def login_enabled(self):
""" Returns true if the login service is enabled. """
pass
@abstractmethod
def get_login_service_id(self, user_info):
""" Returns the internal ID for the given user under this login service. """
pass
@abstractmethod
def get_login_service_username(self, user_info):
""" Returns the username for the given user under this login service. """
pass
@abstractmethod
def get_verified_user_email(self, app_config, http_client, token, user_info):
""" Returns the verified email address for the given user, if any or None if none. """
pass
@abstractmethod
def get_icon(self):
""" Returns the icon to display for this login service. """
pass
@abstractmethod
def get_login_scopes(self):
""" Returns the list of scopes for login for this service. """
pass
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
""" Performs service-specific verification of user information for login. On failure, a service
should raise a OAuthLoginService.
"""
# By default, does nothing.
pass
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
""" Exchanges the given OAuth access code for user information on behalf of a user trying to
login or attach their account. Raises a OAuthLoginService exception on failure. Returns
a tuple consisting of (service_id, service_username, email)
"""
# Retrieve the token for the OAuth code.
try:
token = self.exchange_code_for_token(app_config, http_client, code,
redirect_suffix=redirect_suffix,
form_encode=self.requires_form_encoding())
except OAuthExchangeCodeException as oce:
raise OAuthLoginException(str(oce))
# Retrieve the user's information with the token.
try:
user_info = self.get_user_info(http_client, token)
except OAuthGetUserInfoException as oge:
raise OAuthLoginException(str(oge))
if user_info.get('id', None) is None:
logger.debug('Got user info response %s', user_info)
raise OAuthLoginException('Missing `id` column in returned user information')
# Perform any custom verification for this login service.
self.service_verify_user_info_for_login(app_config, http_client, token, user_info)
# Retrieve the user's email address (if necessary).
email_address = self.get_verified_user_email(app_config, http_client, token, user_info)
if features.MAILING and email_address is None:
raise OAuthLoginException('A verified email address is required to login with this service')
service_user_id = self.get_login_service_id(user_info)
service_username = self.get_login_service_username(user_info)
logger.debug('Completed successful exchange for service %s: %s, %s, %s',
self.service_id(), service_user_id, service_username, email_address)
return (service_user_id, service_username, email_address)

37
oauth/loginmanager.py Normal file
View file

@ -0,0 +1,37 @@
from oauth.services.github import GithubOAuthService
from oauth.services.google import GoogleOAuthService
from oauth.oidc import OIDCLoginService
CUSTOM_LOGIN_SERVICES = {
'GITHUB_LOGIN_CONFIG': GithubOAuthService,
'GOOGLE_LOGIN_CONFIG': GoogleOAuthService,
}
PREFIX_BLACKLIST = ['ldap', 'jwt', 'keystone']
class OAuthLoginManager(object):
""" Helper class which manages all registered OAuth login services. """
def __init__(self, config, client=None):
self.services = []
# Register the endpoints for each of the OAuth login services.
for key in config.keys():
# All keys which end in _LOGIN_CONFIG setup a login service.
if key.endswith('_LOGIN_CONFIG'):
if key in CUSTOM_LOGIN_SERVICES:
custom_service = CUSTOM_LOGIN_SERVICES[key](config, key)
if custom_service.login_enabled(config):
self.services.append(custom_service)
else:
prefix = key.rstrip('_LOGIN_CONFIG').lower()
if prefix in PREFIX_BLACKLIST:
raise Exception('Cannot use reserved config name %s' % key)
self.services.append(OIDCLoginService(config, key, client=client))
def get_service(self, service_id):
for service in self.services:
if service.service_id() == service_id:
return service
return None

328
oauth/oidc.py Normal file
View file

@ -0,0 +1,328 @@
import time
import json
import logging
import urlparse
import jwt
from cachetools.func import lru_cache
from cachetools.ttl import TTLCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_der_public_key
from jwkest.jwk import KEYS
from oauth.base import (OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException,
OAuthEndpoint)
from oauth.login import OAuthLoginException
from util.security.jwtutil import decode, InvalidTokenError
logger = logging.getLogger(__name__)
OIDC_WELLKNOWN = ".well-known/openid-configuration"
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
ALLOWED_ALGORITHMS = ['RS256']
JWT_CLOCK_SKEW_SECONDS = 30
class DiscoveryFailureException(Exception):
""" Exception raised when OIDC discovery fails. """
pass
class PublicKeyLoadException(Exception):
""" Exception raised if loading the OIDC public key fails. """
pass
class OIDCLoginService(OAuthService):
""" Defines a generic service for all OpenID-connect compatible login services. """
def __init__(self, config, key_name, client=None):
super(OIDCLoginService, self).__init__(config, key_name)
self._id = key_name[0:key_name.find('_')].lower()
self._http_client = client or config.get('HTTPCLIENT')
self._mailing = config.get('FEATURE_MAILING', False)
self._public_key_cache = _PublicKeyCache(self, 1, PUBLIC_KEY_CACHE_TTL)
def service_id(self):
return self._id
def service_name(self):
return self.config.get('SERVICE_NAME', self.service_id())
def get_icon(self):
return self.config.get('SERVICE_ICON', 'fa-user-circle')
def get_login_scopes(self):
default_scopes = ['openid']
if self.user_endpoint() is not None:
default_scopes.append('profile')
if self._mailing:
default_scopes.append('email')
supported_scopes = self._oidc_config().get('scopes_supported', default_scopes)
login_scopes = self.config.get('LOGIN_SCOPES') or supported_scopes
return list(set(login_scopes) & set(supported_scopes))
def authorize_endpoint(self):
return self._get_endpoint('authorization_endpoint').with_param('response_type', 'code')
def token_endpoint(self):
return self._get_endpoint('token_endpoint')
def user_endpoint(self):
return self._get_endpoint('userinfo_endpoint')
def _get_endpoint(self, endpoint_key, **kwargs):
""" Returns the OIDC endpoint with the given key found in the OIDC discovery
document, with the given kwargs added as query parameters. Additionally,
any defined parameters found in the OIDC configuration block are also
added.
"""
endpoint = self._oidc_config().get(endpoint_key, '')
if not endpoint:
return None
(scheme, netloc, path, query, fragment) = urlparse.urlsplit(endpoint)
# Add the query parameters from the kwargs and the config.
custom_parameters = self.config.get('OIDC_ENDPOINT_CUSTOM_PARAMS', {}).get(endpoint_key, {})
query_params = urlparse.parse_qs(query, keep_blank_values=True)
query_params.update(kwargs)
query_params.update(custom_parameters)
return OAuthEndpoint(urlparse.urlunsplit((scheme, netloc, path, {}, fragment)), query_params)
def validate(self):
return bool(self.get_login_scopes())
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
# TODO: find a way to verify client secret too.
check_auth_url = http_client.get(self.get_auth_url(url_scheme_and_hostname, '', '', []))
if check_auth_url.status_code // 100 != 2:
raise Exception('Got non-200 status code for authorization endpoint')
def requires_form_encoding(self):
return True
def get_public_config(self):
return {
'CLIENT_ID': self.client_id(),
'OIDC': True,
}
def exchange_code_for_tokens(self, app_config, http_client, code, redirect_suffix):
# Exchange the code for the access token and id_token
try:
json_data = self.exchange_code(app_config, http_client, code,
redirect_suffix=redirect_suffix,
form_encode=self.requires_form_encoding())
except OAuthExchangeCodeException as oce:
raise OAuthLoginException(str(oce))
# Make sure we received both.
access_token = json_data.get('access_token', None)
if access_token is None:
logger.debug('Missing access_token in response: %s', json_data)
raise OAuthLoginException('Missing `access_token` in OIDC response')
id_token = json_data.get('id_token', None)
if id_token is None:
logger.debug('Missing id_token in response: %s', json_data)
raise OAuthLoginException('Missing `id_token` in OIDC response')
return id_token, access_token
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
# Exchange the code for the access token and id_token
id_token, access_token = self.exchange_code_for_tokens(app_config, http_client, code,
redirect_suffix)
# Decode the id_token.
try:
decoded_id_token = self.decode_user_jwt(id_token)
except InvalidTokenError as ite:
logger.exception('Got invalid token error on OIDC decode: %s', ite)
raise OAuthLoginException('Could not decode OIDC token')
except PublicKeyLoadException as pke:
logger.exception('Could not load public key during OIDC decode: %s', pke)
raise OAuthLoginException('Could find public OIDC key')
# If there is a user endpoint, use it to retrieve the user's information. Otherwise, we use
# the decoded ID token.
if self.user_endpoint():
# Retrieve the user information.
try:
user_info = self.get_user_info(http_client, access_token)
except OAuthGetUserInfoException as oge:
raise OAuthLoginException(str(oge))
else:
user_info = decoded_id_token
# Verify subs.
if user_info['sub'] != decoded_id_token['sub']:
logger.debug('Mismatch in `sub` returned by OIDC user info endpoint: %s vs %s',
user_info['sub'], decoded_id_token['sub'])
raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')
# Check if we have a verified email address.
if self.config.get('VERIFIED_EMAIL_CLAIM_NAME'):
email_address = user_info.get(self.config['VERIFIED_EMAIL_CLAIM_NAME'])
else:
email_address = user_info.get('email') if user_info.get('email_verified') else None
logger.debug('Found e-mail address `%s` for sub `%s`', email_address, user_info['sub'])
if self._mailing:
if email_address is None:
raise OAuthLoginException('A verified email address is required to login with this service')
# Check for a preferred username.
if self.config.get('PREFERRED_USERNAME_CLAIM_NAME'):
lusername = user_info.get(self.config['PREFERRED_USERNAME_CLAIM_NAME'])
else:
lusername = user_info.get('preferred_username')
if lusername is None:
# Note: Active Directory provides `unique_name` and `upn`.
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-id-and-access-tokens
lusername = user_info.get('unique_name', user_info.get('upn'))
if lusername is None:
lusername = user_info['sub']
if lusername.find('@') >= 0:
lusername = lusername[0:lusername.find('@')]
return decoded_id_token['sub'], lusername, email_address
@property
def _issuer(self):
# Read the issuer from the OIDC config, falling back to the configured OIDC server.
issuer = self._oidc_config().get('issuer', self.config['OIDC_SERVER'])
# If specified, use the overridden OIDC issuer.
return self.config.get('OIDC_ISSUER', issuer)
@lru_cache(maxsize=1)
def _oidc_config(self):
if self.config.get('OIDC_SERVER'):
return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
else:
return {}
def _load_oidc_config_via_discovery(self, is_debugging):
""" Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True,
non-secure connections are alllowed. Raises an DiscoveryFailureException on failure.
"""
oidc_server = self.config['OIDC_SERVER']
if not oidc_server.startswith('https://') and not is_debugging:
raise DiscoveryFailureException('OIDC server must be accessed over SSL')
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
discovery = self._http_client.get(discovery_url, timeout=5, verify=is_debugging is False)
if discovery.status_code // 100 != 2:
logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
raise DiscoveryFailureException("Could not load OIDC discovery information")
try:
return json.loads(discovery.text)
except ValueError:
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
raise DiscoveryFailureException("Could not parse OIDC discovery information")
def decode_user_jwt(self, token):
""" Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
exception on an invalid token or a PublicKeyLoadException if the public key could not be
loaded for decoding.
"""
# Find the key to use.
headers = jwt.get_unverified_header(token)
kid = headers.get('kid', None)
if kid is None:
raise InvalidTokenError('Missing `kid` header')
logger.debug('Using key `%s`, attempting to decode token `%s` with aud `%s` and iss `%s`',
kid, token, self.client_id(), self._issuer)
try:
return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
except InvalidTokenError as ite:
logger.warning('Could not decode token `%s` for OIDC: %s. Will attempt again after ' +
'retrieving public keys.', token, ite)
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
try:
return decode(token, self._get_public_key(kid, force_refresh=True),
algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False))
except InvalidTokenError as ite:
logger.warning('Could not decode token `%s` for OIDC: %s. Attempted again after ' +
'retrieving public keys.', token, ite)
# Decode again with verify=False, and log the decoded token to allow for easier debugging.
nonverified = decode(token, self._get_public_key(kid, force_refresh=True),
algorithms=ALLOWED_ALGORITHMS,
audience=self.client_id(),
issuer=self._issuer,
leeway=JWT_CLOCK_SKEW_SECONDS,
options=dict(require_nbf=False, verify=False))
logger.debug('Got an error when trying to verify OIDC JWT: %s', nonverified)
raise ite
def _get_public_key(self, kid, force_refresh=False):
""" Retrieves the public key for this handler with the given kid. Raises a
PublicKeyLoadException on failure. """
# If force_refresh is true, we expire all the items in the cache by setting the time to
# the current time + the expiration TTL.
if force_refresh:
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
# Retrieve the public key from the cache. If the cache does not contain the public key,
# it will internally call _load_public_key to retrieve it and then save it.
return self._public_key_cache[kid]
class _PublicKeyCache(TTLCache):
def __init__(self, login_service, *args, **kwargs):
super(_PublicKeyCache, self).__init__(*args, **kwargs)
self._login_service = login_service
def __missing__(self, kid):
""" Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
on failure.
"""
keys_url = self._login_service._oidc_config()['jwks_uri']
# Load the keys.
try:
keys = KEYS()
keys.load_from_url(keys_url, verify=not self._login_service.config.get('DEBUGGING', False))
except Exception as ex:
logger.exception('Exception loading public key')
raise PublicKeyLoadException(str(ex))
# Find the matching key.
keys_found = keys.by_kid(kid)
if len(keys_found) == 0:
raise PublicKeyLoadException('Public key %s not found' % kid)
rsa_keys = [key for key in keys_found if key.kty == 'RSA']
if len(rsa_keys) == 0:
raise PublicKeyLoadException('No RSA form of public key %s not found' % kid)
matching_key = rsa_keys[0]
matching_key.deserialize()
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
# issues.
final_key = load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())
self[kid] = final_key
return final_key

View file

180
oauth/services/github.py Normal file
View file

@ -0,0 +1,180 @@
import logging
from oauth.base import OAuthEndpoint
from oauth.login import OAuthLoginService, OAuthLoginException
from util import slash_join
logger = logging.getLogger(__name__)
class GithubOAuthService(OAuthLoginService):
def __init__(self, config, key_name):
super(GithubOAuthService, self).__init__(config, key_name)
def login_enabled(self, config):
return config.get('FEATURE_GITHUB_LOGIN', False)
def service_id(self):
return 'github'
def service_name(self):
if self.is_enterprise():
return 'GitHub Enterprise'
return 'GitHub'
def get_icon(self):
return 'fa-github'
def get_login_scopes(self):
if self.config.get('ORG_RESTRICT'):
return ['user:email', 'read:org']
return ['user:email']
def allowed_organizations(self):
if not self.config.get('ORG_RESTRICT', False):
return None
allowed = self.config.get('ALLOWED_ORGANIZATIONS', None)
if allowed is None:
return None
return [org.lower() for org in allowed]
def get_public_url(self, suffix):
return slash_join(self._endpoint(), suffix)
def _endpoint(self):
return self.config.get('GITHUB_ENDPOINT', 'https://github.com')
def is_enterprise(self):
return self._api_endpoint().find('.github.com') < 0
def authorize_endpoint(self):
return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/authorize'))
def token_endpoint(self):
return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/access_token'))
def user_endpoint(self):
return OAuthEndpoint(slash_join(self._api_endpoint(), 'user'))
def _api_endpoint(self):
return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/'))
def api_endpoint(self):
endpoint = self._api_endpoint()
if endpoint.endswith('/'):
return endpoint[0:-1]
return endpoint
def email_endpoint(self):
return slash_join(self._api_endpoint(), 'user/emails')
def orgs_endpoint(self):
return slash_join(self._api_endpoint(), 'user/orgs')
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
# First: Verify that the github endpoint is actually Github by checking for the
# X-GitHub-Request-Id here.
api_endpoint = self._api_endpoint()
result = http_client.get(api_endpoint, auth=(self.client_id(), self.client_secret()), timeout=5)
if not 'X-GitHub-Request-Id' in result.headers:
raise Exception('Endpoint is not a Github (Enterprise) installation')
# Next: Verify the client ID and secret.
# Note: The following code is a hack until such time as Github officially adds an API endpoint
# for verifying a {client_id, client_secret} pair. This workaround was given to us
# *by a Github Engineer* (Jan 8, 2015).
#
# TODO: Replace with the real API call once added.
#
# Hitting the endpoint applications/{client_id}/tokens/foo will result in the following
# behavior IF the client_id is given as the HTTP username and the client_secret as the HTTP
# password:
# - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error.
# - If the pair is valid, then we get a 404 because the 'foo' token does not exists.
validate_endpoint = slash_join(api_endpoint, 'applications/%s/tokens/foo' % self.client_id())
result = http_client.get(validate_endpoint, auth=(self.client_id(), self.client_secret()),
timeout=5)
return result.status_code == 404
def validate_organization(self, organization_id, http_client):
org_endpoint = slash_join(self._api_endpoint(), 'orgs/%s' % organization_id.lower())
result = http_client.get(org_endpoint,
headers={'Accept': 'application/vnd.github.moondragon+json'},
timeout=5)
return result.status_code == 200
def get_public_config(self):
return {
'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url(),
'GITHUB_ENDPOINT': self._endpoint(),
'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
}
def get_login_service_id(self, user_info):
return user_info['id']
def get_login_service_username(self, user_info):
return user_info['login']
def get_verified_user_email(self, app_config, http_client, token, user_info):
v3_media_type = {
'Accept': 'application/vnd.github.v3'
}
token_param = {
'access_token': token,
}
# Find the e-mail address for the user: we will accept any email, but we prefer the primary
get_email = http_client.get(self.email_endpoint(), params=token_param, headers=v3_media_type)
if get_email.status_code // 100 != 2:
raise OAuthLoginException('Got non-2XX status code for emails endpoint: %s' %
get_email.status_code)
verified_emails = [email for email in get_email.json() if email['verified']]
primary_emails = [email for email in get_email.json() if email['primary']]
# Special case: We don't care about whether an e-mail address is "verified" under GHE.
if self.is_enterprise() and not verified_emails:
verified_emails = primary_emails
allowed_emails = (primary_emails or verified_emails or [])
return allowed_emails[0]['email'] if len(allowed_emails) > 0 else None
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
# Retrieve the user's orgnizations (if organization filtering is turned on)
if self.allowed_organizations() is None:
return
moondragon_media_type = {
'Accept': 'application/vnd.github.moondragon+json'
}
token_param = {
'access_token': token,
}
get_orgs = http_client.get(self.orgs_endpoint(), params=token_param,
headers=moondragon_media_type)
if get_orgs.status_code // 100 != 2:
logger.debug('get_orgs response: %s', get_orgs.json())
raise OAuthLoginException('Got non-2XX response for org lookup: %s' %
get_orgs.status_code)
organizations = set([org.get('login').lower() for org in get_orgs.json()])
matching_organizations = organizations & set(self.allowed_organizations())
if not matching_organizations:
logger.debug('Found organizations %s, but expected one of %s', organizations,
self.allowed_organizations())
err = """You are not a member of an allowed GitHub organization.
Please contact your system administrator if you believe this is in error."""
raise OAuthLoginException(err)

60
oauth/services/gitlab.py Normal file
View file

@ -0,0 +1,60 @@
from oauth.base import OAuthService, OAuthEndpoint
from util import slash_join
class GitLabOAuthService(OAuthService):
def __init__(self, config, key_name):
super(GitLabOAuthService, self).__init__(config, key_name)
def service_id(self):
return 'gitlab'
def service_name(self):
return 'GitLab'
def _endpoint(self):
return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')
def user_endpoint(self):
raise NotImplementedError
def api_endpoint(self):
return self._endpoint()
def get_public_url(self, suffix):
return slash_join(self._endpoint(), suffix)
def authorize_endpoint(self):
return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/authorize'))
def token_endpoint(self):
return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/token'))
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
# We validate the client ID and secret by hitting the OAuth token exchange endpoint with
# the real client ID and secret, but a fake auth code to exchange. Gitlab's implementation will
# return `invalid_client` as the `error` if the client ID or secret is invalid; otherwise, it
# will return another error.
url = self.token_endpoint().to_url()
redirect_uri = self.get_redirect_uri(url_scheme_and_hostname, redirect_suffix='trigger')
data = {
'code': 'fakecode',
'client_id': self.client_id(),
'client_secret': self.client_secret(),
'grant_type': 'authorization_code',
'redirect_uri': redirect_uri
}
# We validate by checking the error code we receive from this call.
result = http_client.post(url, data=data, timeout=5)
value = result.json()
if not value:
return False
return value.get('error', '') != 'invalid_client'
def get_public_config(self):
return {
'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url(),
'GITLAB_ENDPOINT': self._endpoint(),
}

81
oauth/services/google.py Normal file
View file

@ -0,0 +1,81 @@
from oauth.base import OAuthEndpoint
from oauth.login import OAuthLoginService
def _get_email_username(email_address):
username = email_address
at = username.find('@')
if at > 0:
username = username[0:at]
return username
class GoogleOAuthService(OAuthLoginService):
def __init__(self, config, key_name):
super(GoogleOAuthService, self).__init__(config, key_name)
def login_enabled(self, config):
return config.get('FEATURE_GOOGLE_LOGIN', False)
def service_id(self):
return 'google'
def service_name(self):
return 'Google'
def get_icon(self):
return 'fa-google'
def get_login_scopes(self):
return ['openid', 'email']
def authorize_endpoint(self):
return OAuthEndpoint('https://accounts.google.com/o/oauth2/auth',
params=dict(response_type='code'))
def token_endpoint(self):
return OAuthEndpoint('https://accounts.google.com/o/oauth2/token')
def user_endpoint(self):
return OAuthEndpoint('https://www.googleapis.com/oauth2/v1/userinfo')
def requires_form_encoding(self):
return True
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
# To verify the Google client ID and secret, we hit the
# https://www.googleapis.com/oauth2/v3/token endpoint with an invalid request. If the client
# ID or secret are invalid, we get returned a 403 Unauthorized. Otherwise, we get returned
# another response code.
url = 'https://www.googleapis.com/oauth2/v3/token'
data = {
'code': 'fakecode',
'client_id': self.client_id(),
'client_secret': self.client_secret(),
'grant_type': 'authorization_code',
'redirect_uri': 'http://example.com'
}
result = http_client.post(url, data=data, timeout=5)
return result.status_code != 401
def get_public_config(self):
return {
'CLIENT_ID': self.client_id(),
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url()
}
def get_login_service_id(self, user_info):
return user_info['id']
def get_login_service_username(self, user_info):
return _get_email_username(user_info['email'])
def get_verified_user_email(self, app_config, http_client, token, user_info):
if not user_info.get('verified_email', False):
return None
return user_info['email']
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
# Nothing to do.
pass

View file

@ -0,0 +1,38 @@
import pytest
from oauth.services.github import GithubOAuthService
@pytest.mark.parametrize('trigger_config, domain, api_endpoint, is_enterprise', [
({
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
'API_ENDPOINT': 'https://api.github.com/v3',
}, 'https://github.com', 'https://api.github.com/v3', False),
({
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}, 'https://github.somedomain.com', 'https://github.somedomain.com/api/v3', True),
({
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
'API_ENDPOINT': 'http://somedomain.com/api/',
'CLIENT_ID': 'someclientid',
'CLIENT_SECRET': 'someclientsecret',
}, 'https://github.somedomain.com', 'http://somedomain.com/api', True),
])
def test_basic_enterprise_config(trigger_config, domain, api_endpoint, is_enterprise):
config = {
'GITHUB_TRIGGER_CONFIG': trigger_config
}
github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG')
assert github_trigger.is_enterprise() == is_enterprise
assert github_trigger.authorize_endpoint().to_url() == '%s/login/oauth/authorize' % domain
assert github_trigger.token_endpoint().to_url() == '%s/login/oauth/access_token' % domain
assert github_trigger.api_endpoint() == api_endpoint
assert github_trigger.user_endpoint().to_url() == '%s/user' % api_endpoint
assert github_trigger.email_endpoint() == '%s/user/emails' % api_endpoint
assert github_trigger.orgs_endpoint() == '%s/user/orgs' % api_endpoint

0
oauth/test/__init__.py Normal file
View file

View file

@ -0,0 +1,62 @@
from oauth.loginmanager import OAuthLoginManager
from oauth.services.github import GithubOAuthService
from oauth.services.google import GoogleOAuthService
from oauth.oidc import OIDCLoginService
def test_login_manager_github():
config = {
'FEATURE_GITHUB_LOGIN': True,
'GITHUB_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], GithubOAuthService)
def test_github_disabled():
config = {
'GITHUB_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 0
def test_login_manager_google():
config = {
'FEATURE_GOOGLE_LOGIN': True,
'GOOGLE_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], GoogleOAuthService)
def test_google_disabled():
config = {
'GOOGLE_LOGIN_CONFIG': {},
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 0
def test_oidc():
config = {
'SOMECOOL_LOGIN_CONFIG': {},
'HTTPCLIENT': None,
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 1
assert isinstance(loginmanager.services[0], OIDCLoginService)
def test_multiple_oidc():
config = {
'SOMECOOL_LOGIN_CONFIG': {},
'ANOTHER_LOGIN_CONFIG': {},
'HTTPCLIENT': None,
}
loginmanager = OAuthLoginManager(config)
assert len(loginmanager.services) == 2
assert isinstance(loginmanager.services[0], OIDCLoginService)
assert isinstance(loginmanager.services[1], OIDCLoginService)

342
oauth/test/test_oidc.py Normal file
View file

@ -0,0 +1,342 @@
# pylint: disable=redefined-outer-name, unused-argument, invalid-name, missing-docstring, too-many-arguments
import json
import time
import urlparse
import jwt
import pytest
import requests
from httmock import urlmatch, HTTMock
from Crypto.PublicKey import RSA
from jwkest.jwk import RSAKey
from oauth.oidc import OIDCLoginService, OAuthLoginException
from util.config import URLSchemeAndHostname
@pytest.fixture(scope='module') # Slow to generate, only do it once.
def signing_key():
private_key = RSA.generate(2048)
jwk = RSAKey(key=private_key.publickey()).serialize()
return {
'id': 'somekey',
'private_key': private_key.exportKey('PEM'),
'jwk': jwk,
}
@pytest.fixture(scope="module")
def http_client():
sess = requests.Session()
adapter = requests.adapters.HTTPAdapter(pool_connections=100,
pool_maxsize=100)
sess.mount('http://', adapter)
sess.mount('https://', adapter)
return sess
@pytest.fixture(scope="module")
def valid_code():
return 'validcode'
@pytest.fixture(params=[True, False])
def mailing_feature(request):
return request.param
@pytest.fixture(params=[True, False])
def email_verified(request):
return request.param
@pytest.fixture(params=[True, False])
def userinfo_supported(request):
return request.param
@pytest.fixture(params=["someusername", "foo@bar.com", None])
def preferred_username(request):
return request.param
@pytest.fixture()
def app_config(http_client, mailing_feature):
return {
'PREFERRED_URL_SCHEME': 'http',
'SERVER_HOSTNAME': 'localhost',
'FEATURE_MAILING': mailing_feature,
'SOMEOIDC_LOGIN_CONFIG': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'SERVICE_NAME': 'Some Cool Service',
'SERVICE_ICON': 'http://some/icon',
'OIDC_SERVER': 'http://fakeoidc',
'DEBUGGING': True,
},
'ANOTHEROIDC_LOGIN_CONFIG': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'SERVICE_NAME': 'Some Other Service',
'SERVICE_ICON': 'http://some/icon',
'OIDC_SERVER': 'http://fakeoidc',
'LOGIN_SCOPES': ['openid'],
'DEBUGGING': True,
},
'OIDCWITHPARAMS_LOGIN_CONFIG': {
'CLIENT_ID': 'foo',
'CLIENT_SECRET': 'bar',
'SERVICE_NAME': 'Some Other Service',
'SERVICE_ICON': 'http://some/icon',
'OIDC_SERVER': 'http://fakeoidc',
'DEBUGGING': True,
'OIDC_ENDPOINT_CUSTOM_PARAMS': {
'authorization_endpoint': {
'some': 'param',
},
},
},
'HTTPCLIENT': http_client,
}
@pytest.fixture()
def oidc_service(app_config):
return OIDCLoginService(app_config, 'SOMEOIDC_LOGIN_CONFIG')
@pytest.fixture()
def another_oidc_service(app_config):
return OIDCLoginService(app_config, 'ANOTHEROIDC_LOGIN_CONFIG')
@pytest.fixture()
def oidc_withparams_service(app_config):
return OIDCLoginService(app_config, 'OIDCWITHPARAMS_LOGIN_CONFIG')
@pytest.fixture()
def discovery_content(userinfo_supported):
return {
'scopes_supported': ['openid', 'profile', 'somescope'],
'authorization_endpoint': 'http://fakeoidc/authorize',
'token_endpoint': 'http://fakeoidc/token',
'userinfo_endpoint': 'http://fakeoidc/userinfo' if userinfo_supported else None,
'jwks_uri': 'http://fakeoidc/jwks',
}
@pytest.fixture()
def userinfo_content(preferred_username, email_verified):
return {
'sub': 'cooluser',
'preferred_username': preferred_username,
'email': 'foo@example.com',
'email_verified': email_verified,
}
@pytest.fixture()
def id_token(oidc_service, signing_key, userinfo_content, app_config):
token_data = {
'iss': oidc_service.config['OIDC_SERVER'],
'aud': oidc_service.client_id(),
'nbf': int(time.time()),
'iat': int(time.time()),
'exp': int(time.time() + 600),
'sub': 'cooluser',
}
token_data.update(userinfo_content)
token_headers = {
'kid': signing_key['id'],
}
return jwt.encode(token_data, signing_key['private_key'], 'RS256', headers=token_headers)
@pytest.fixture()
def discovery_handler(discovery_content):
@urlmatch(netloc=r'fakeoidc', path=r'.+openid.+')
def handler(_, __):
return json.dumps(discovery_content)
return handler
@pytest.fixture()
def authorize_handler(discovery_content):
@urlmatch(netloc=r'fakeoidc', path=r'/authorize')
def handler(_, request):
parsed = urlparse.urlparse(request.url)
params = urlparse.parse_qs(parsed.query)
return json.dumps({'authorized': True, 'scope': params['scope'][0], 'state': params['state'][0]})
return handler
@pytest.fixture()
def token_handler(oidc_service, id_token, valid_code):
@urlmatch(netloc=r'fakeoidc', path=r'/token')
def handler(_, request):
params = urlparse.parse_qs(request.body)
if params.get('redirect_uri')[0] != 'http://localhost/oauth2/someoidc/callback':
return {'status_code': 400, 'content': 'Invalid redirect URI'}
if params.get('client_id')[0] != oidc_service.client_id():
return {'status_code': 401, 'content': 'Invalid client id'}
if params.get('client_secret')[0] != oidc_service.client_secret():
return {'status_code': 401, 'content': 'Invalid client secret'}
if params.get('code')[0] != valid_code:
return {'status_code': 401, 'content': 'Invalid code'}
if params.get('grant_type')[0] != 'authorization_code':
return {'status_code': 400, 'content': 'Invalid authorization type'}
content = {
'access_token': 'sometoken',
'id_token': id_token,
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def jwks_handler(signing_key):
def jwk_with_kid(kid, jwk):
jwk = jwk.copy()
jwk.update({'kid': kid})
return jwk
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
def handler(_, __):
content = {'keys': [jwk_with_kid(signing_key['id'], signing_key['jwk'])]}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture()
def emptykeys_jwks_handler():
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
def handler(_, __):
content = {'keys': []}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
@pytest.fixture
def userinfo_handler(oidc_service, userinfo_content):
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
def handler(_, req):
if req.headers.get('Authorization') != 'Bearer sometoken':
return {'status_code': 401, 'content': 'Missing expected header'}
return {'status_code': 200, 'content': json.dumps(userinfo_content)}
return handler
@pytest.fixture()
def invalidsub_userinfo_handler(oidc_service):
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
def handler(_, __):
content = {
'sub': 'invalidsub',
}
return {'status_code': 200, 'content': json.dumps(content)}
return handler
def test_basic_config(oidc_service):
assert oidc_service.service_id() == 'someoidc'
assert oidc_service.service_name() == 'Some Cool Service'
assert oidc_service.get_icon() == 'http://some/icon'
def test_discovery(oidc_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler):
auth = discovery_content['authorization_endpoint'] + '?response_type=code'
assert oidc_service.authorize_endpoint().to_url() == auth
assert oidc_service.token_endpoint().to_url() == discovery_content['token_endpoint']
if discovery_content['userinfo_endpoint'] is None:
assert oidc_service.user_endpoint() is None
else:
assert oidc_service.user_endpoint().to_url() == discovery_content['userinfo_endpoint']
assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported'])
def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler):
assert 'some=param' in oidc_withparams_service.authorize_endpoint().to_url()
def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler):
with HTTMock(discovery_handler):
assert another_oidc_service.get_login_scopes() == ['openid']
def test_public_config(oidc_service, discovery_handler):
with HTTMock(discovery_handler):
assert oidc_service.get_public_config()['OIDC']
assert oidc_service.get_public_config()['CLIENT_ID'] == 'foo'
assert 'CLIENT_SECRET' not in oidc_service.get_public_config()
assert 'bar' not in oidc_service.get_public_config().values()
def test_auth_url(oidc_service, discovery_handler, http_client, authorize_handler):
config = {'PREFERRED_URL_SCHEME': 'https', 'SERVER_HOSTNAME': 'someserver'}
with HTTMock(discovery_handler, authorize_handler):
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(config)
auth_url = oidc_service.get_auth_url(url_scheme_and_hostname, '', 'some csrf token', ['one', 'two'])
# Hit the URL and ensure it works.
result = http_client.get(auth_url).json()
assert result['state'] == 'some csrf token'
assert result['scope'] == 'one two'
def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client,
token_handler):
with HTTMock(token_handler, discovery_handler):
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, 'testcode', '')
def test_exchange_code_invalidsub(oidc_service, discovery_handler, app_config, http_client,
token_handler, invalidsub_userinfo_handler, jwks_handler,
valid_code, userinfo_supported):
# Skip when userinfo is not supported.
if not userinfo_supported:
return
with HTTMock(jwks_handler, token_handler, invalidsub_userinfo_handler, discovery_handler):
# Should fail because the sub of the user info doesn't match that returned by the id_token.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
def test_exchange_code_missingkey(oidc_service, discovery_handler, app_config, http_client,
token_handler, userinfo_handler, emptykeys_jwks_handler,
valid_code):
with HTTMock(emptykeys_jwks_handler, token_handler, userinfo_handler, discovery_handler):
# Should fail because the key is missing.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
def test_exchange_code_validcode(oidc_service, discovery_handler, app_config, http_client,
token_handler, userinfo_handler, jwks_handler, valid_code,
preferred_username, mailing_feature, email_verified):
with HTTMock(jwks_handler, token_handler, userinfo_handler, discovery_handler):
if mailing_feature and not email_verified:
# Should fail because there isn't a verified email address.
with pytest.raises(OAuthLoginException):
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
else:
# Should succeed.
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
valid_code, '')
assert lid == 'cooluser'
if email_verified:
assert lemail == 'foo@example.com'
else:
assert lemail is None
if preferred_username is not None:
if preferred_username.find('@') >= 0:
preferred_username = preferred_username[0:preferred_username.find('@')]
assert lusername == preferred_username
else:
assert lusername == lid