initial import for Open Source 🎉
This commit is contained in:
parent
1898c361f3
commit
9c0dd3b722
2048 changed files with 218743 additions and 0 deletions
0
oauth/__init__.py
Normal file
0
oauth/__init__.py
Normal file
193
oauth/base.py
Normal file
193
oauth/base.py
Normal file
|
@ -0,0 +1,193 @@
|
|||
import copy
|
||||
import logging
|
||||
import urllib
|
||||
import urlparse
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import add_metaclass
|
||||
|
||||
from util.config import URLSchemeAndHostname
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OAuthEndpoint(object):
|
||||
def __init__(self, base_url, params=None):
|
||||
self.base_url = base_url
|
||||
self.params = params or {}
|
||||
|
||||
def with_param(self, name, value):
|
||||
params_copy = copy.copy(self.params)
|
||||
params_copy[name] = value
|
||||
return OAuthEndpoint(self.base_url, params_copy)
|
||||
|
||||
def with_params(self, parameters):
|
||||
params_copy = copy.copy(self.params)
|
||||
params_copy.update(parameters)
|
||||
return OAuthEndpoint(self.base_url, params_copy)
|
||||
|
||||
def to_url(self):
|
||||
(scheme, netloc, path, _, fragment) = urlparse.urlsplit(self.base_url)
|
||||
updated_query = urllib.urlencode(self.params)
|
||||
return urlparse.urlunsplit((scheme, netloc, path, updated_query, fragment))
|
||||
|
||||
class OAuthExchangeCodeException(Exception):
|
||||
""" Exception raised if a code exchange fails. """
|
||||
pass
|
||||
|
||||
class OAuthGetUserInfoException(Exception):
|
||||
""" Exception raised if a call to get user information fails. """
|
||||
pass
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class OAuthService(object):
|
||||
""" A base class for defining an external service, exposed via OAuth. """
|
||||
def __init__(self, config, key_name):
|
||||
self.key_name = key_name
|
||||
self.config = config.get(key_name) or {}
|
||||
|
||||
@abstractmethod
|
||||
def service_id(self):
|
||||
""" The internal ID for this service. Must match the URL portion for the service, e.g. `github`
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def service_name(self):
|
||||
""" The user-readable name for the service, e.g. `GitHub`"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def token_endpoint(self):
|
||||
""" Returns the endpoint at which the OAuth code can be exchanged for a token. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def user_endpoint(self):
|
||||
""" Returns the endpoint at which user information can be looked up. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def authorize_endpoint(self):
|
||||
""" Returns the for authorization of the OAuth service. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
|
||||
""" Performs validation of the client ID and secret, raising an exception on failure. """
|
||||
pass
|
||||
|
||||
def requires_form_encoding(self):
|
||||
""" Returns True if form encoding is necessary for the exchange_code_for_token call. """
|
||||
return False
|
||||
|
||||
def client_id(self):
|
||||
return self.config.get('CLIENT_ID')
|
||||
|
||||
def client_secret(self):
|
||||
return self.config.get('CLIENT_SECRET')
|
||||
|
||||
def login_binding_field(self):
|
||||
""" Returns the name of the field (`username` or `email`) used for auto binding an external
|
||||
login service account to an *internal* login service account. For example, if the external
|
||||
login service is GitHub and the internal login service is LDAP, a value of `email` here
|
||||
will cause login-with-Github to conduct a search (via email) in LDAP for a user, an auto
|
||||
bind the external and internal users together. May return None, in which case no binding
|
||||
is performing, and login with this external account will simply create a new account in the
|
||||
database.
|
||||
"""
|
||||
return self.config.get('LOGIN_BINDING_FIELD', None)
|
||||
|
||||
def get_auth_url(self, url_scheme_and_hostname, redirect_suffix, csrf_token, scopes):
|
||||
""" Retrieves the authorization URL for this login service. """
|
||||
redirect_uri = '%s/oauth2/%s/callback%s' % (url_scheme_and_hostname.get_url(),
|
||||
self.service_id(),
|
||||
redirect_suffix)
|
||||
params = {
|
||||
'client_id': self.client_id(),
|
||||
'redirect_uri': redirect_uri,
|
||||
'scope': ' '.join(scopes),
|
||||
'state': csrf_token,
|
||||
}
|
||||
|
||||
return self.authorize_endpoint().with_params(params).to_url()
|
||||
|
||||
def get_redirect_uri(self, url_scheme_and_hostname, redirect_suffix=''):
|
||||
return '%s://%s/oauth2/%s/callback%s' % (url_scheme_and_hostname.url_scheme,
|
||||
url_scheme_and_hostname.hostname,
|
||||
self.service_id(),
|
||||
redirect_suffix)
|
||||
|
||||
def get_user_info(self, http_client, token):
|
||||
token_param = {
|
||||
'alt': 'json',
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Authorization': 'Bearer %s' % token,
|
||||
}
|
||||
|
||||
got_user = http_client.get(self.user_endpoint().to_url(), params=token_param, headers=headers)
|
||||
if got_user.status_code // 100 != 2:
|
||||
raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' %
|
||||
got_user.status_code)
|
||||
|
||||
user_info = got_user.json()
|
||||
if user_info is None:
|
||||
raise OAuthGetUserInfoException()
|
||||
|
||||
return user_info
|
||||
|
||||
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
|
||||
redirect_suffix='', client_auth=False):
|
||||
""" Exchanges an OAuth access code for the associated OAuth token. """
|
||||
json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
|
||||
client_auth)
|
||||
|
||||
access_token = json_data.get('access_token', None)
|
||||
if access_token is None:
|
||||
logger.debug('Got successful get_access_token response with missing token: %s', json_data)
|
||||
raise OAuthExchangeCodeException('Missing `access_token` in OAuth response')
|
||||
|
||||
return access_token
|
||||
|
||||
def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
|
||||
client_auth=False):
|
||||
""" Exchanges an OAuth access code for associated OAuth token and other data. """
|
||||
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(app_config)
|
||||
payload = {
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': self.get_redirect_uri(url_scheme_and_hostname, redirect_suffix)
|
||||
}
|
||||
|
||||
headers = {
|
||||
'Accept': 'application/json'
|
||||
}
|
||||
|
||||
auth = None
|
||||
if client_auth:
|
||||
auth = (self.client_id(), self.client_secret())
|
||||
else:
|
||||
payload['client_id'] = self.client_id()
|
||||
payload['client_secret'] = self.client_secret()
|
||||
|
||||
token_url = self.token_endpoint().to_url()
|
||||
if form_encode:
|
||||
get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
|
||||
else:
|
||||
get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)
|
||||
|
||||
if get_access_token.status_code // 100 != 2:
|
||||
logger.debug('Got get_access_token response %s', get_access_token.text)
|
||||
raise OAuthExchangeCodeException('Got non-2XX response for code exchange: %s' %
|
||||
get_access_token.status_code)
|
||||
|
||||
json_data = get_access_token.json()
|
||||
if not json_data:
|
||||
raise OAuthExchangeCodeException('Got non-JSON response for code exchange')
|
||||
|
||||
if 'error' in json_data:
|
||||
raise OAuthExchangeCodeException(json_data.get('error_description', json_data['error']))
|
||||
|
||||
return json_data
|
96
oauth/login.py
Normal file
96
oauth/login.py
Normal file
|
@ -0,0 +1,96 @@
|
|||
import logging
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from six import add_metaclass
|
||||
|
||||
import features
|
||||
|
||||
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class OAuthLoginException(Exception):
|
||||
""" Exception raised if a login operation fails. """
|
||||
pass
|
||||
|
||||
|
||||
@add_metaclass(ABCMeta)
|
||||
class OAuthLoginService(OAuthService):
|
||||
""" A base class for defining an OAuth-compliant service that can be used for, amongst other
|
||||
things, login and authentication. """
|
||||
|
||||
@abstractmethod
|
||||
def login_enabled(self):
|
||||
""" Returns true if the login service is enabled. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_login_service_id(self, user_info):
|
||||
""" Returns the internal ID for the given user under this login service. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_login_service_username(self, user_info):
|
||||
""" Returns the username for the given user under this login service. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_verified_user_email(self, app_config, http_client, token, user_info):
|
||||
""" Returns the verified email address for the given user, if any or None if none. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_icon(self):
|
||||
""" Returns the icon to display for this login service. """
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_login_scopes(self):
|
||||
""" Returns the list of scopes for login for this service. """
|
||||
pass
|
||||
|
||||
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
|
||||
""" Performs service-specific verification of user information for login. On failure, a service
|
||||
should raise a OAuthLoginService.
|
||||
"""
|
||||
# By default, does nothing.
|
||||
pass
|
||||
|
||||
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
|
||||
""" Exchanges the given OAuth access code for user information on behalf of a user trying to
|
||||
login or attach their account. Raises a OAuthLoginService exception on failure. Returns
|
||||
a tuple consisting of (service_id, service_username, email)
|
||||
"""
|
||||
|
||||
# Retrieve the token for the OAuth code.
|
||||
try:
|
||||
token = self.exchange_code_for_token(app_config, http_client, code,
|
||||
redirect_suffix=redirect_suffix,
|
||||
form_encode=self.requires_form_encoding())
|
||||
except OAuthExchangeCodeException as oce:
|
||||
raise OAuthLoginException(str(oce))
|
||||
|
||||
# Retrieve the user's information with the token.
|
||||
try:
|
||||
user_info = self.get_user_info(http_client, token)
|
||||
except OAuthGetUserInfoException as oge:
|
||||
raise OAuthLoginException(str(oge))
|
||||
|
||||
if user_info.get('id', None) is None:
|
||||
logger.debug('Got user info response %s', user_info)
|
||||
raise OAuthLoginException('Missing `id` column in returned user information')
|
||||
|
||||
# Perform any custom verification for this login service.
|
||||
self.service_verify_user_info_for_login(app_config, http_client, token, user_info)
|
||||
|
||||
# Retrieve the user's email address (if necessary).
|
||||
email_address = self.get_verified_user_email(app_config, http_client, token, user_info)
|
||||
if features.MAILING and email_address is None:
|
||||
raise OAuthLoginException('A verified email address is required to login with this service')
|
||||
|
||||
service_user_id = self.get_login_service_id(user_info)
|
||||
service_username = self.get_login_service_username(user_info)
|
||||
|
||||
logger.debug('Completed successful exchange for service %s: %s, %s, %s',
|
||||
self.service_id(), service_user_id, service_username, email_address)
|
||||
return (service_user_id, service_username, email_address)
|
37
oauth/loginmanager.py
Normal file
37
oauth/loginmanager.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
from oauth.services.github import GithubOAuthService
|
||||
from oauth.services.google import GoogleOAuthService
|
||||
from oauth.oidc import OIDCLoginService
|
||||
|
||||
CUSTOM_LOGIN_SERVICES = {
|
||||
'GITHUB_LOGIN_CONFIG': GithubOAuthService,
|
||||
'GOOGLE_LOGIN_CONFIG': GoogleOAuthService,
|
||||
}
|
||||
|
||||
PREFIX_BLACKLIST = ['ldap', 'jwt', 'keystone']
|
||||
|
||||
class OAuthLoginManager(object):
|
||||
""" Helper class which manages all registered OAuth login services. """
|
||||
def __init__(self, config, client=None):
|
||||
self.services = []
|
||||
|
||||
# Register the endpoints for each of the OAuth login services.
|
||||
for key in config.keys():
|
||||
# All keys which end in _LOGIN_CONFIG setup a login service.
|
||||
if key.endswith('_LOGIN_CONFIG'):
|
||||
if key in CUSTOM_LOGIN_SERVICES:
|
||||
custom_service = CUSTOM_LOGIN_SERVICES[key](config, key)
|
||||
if custom_service.login_enabled(config):
|
||||
self.services.append(custom_service)
|
||||
else:
|
||||
prefix = key.rstrip('_LOGIN_CONFIG').lower()
|
||||
if prefix in PREFIX_BLACKLIST:
|
||||
raise Exception('Cannot use reserved config name %s' % key)
|
||||
|
||||
self.services.append(OIDCLoginService(config, key, client=client))
|
||||
|
||||
def get_service(self, service_id):
|
||||
for service in self.services:
|
||||
if service.service_id() == service_id:
|
||||
return service
|
||||
|
||||
return None
|
328
oauth/oidc.py
Normal file
328
oauth/oidc.py
Normal file
|
@ -0,0 +1,328 @@
|
|||
import time
|
||||
import json
|
||||
import logging
|
||||
import urlparse
|
||||
|
||||
import jwt
|
||||
|
||||
from cachetools.func import lru_cache
|
||||
from cachetools.ttl import TTLCache
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
from jwkest.jwk import KEYS
|
||||
|
||||
from oauth.base import (OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException,
|
||||
OAuthEndpoint)
|
||||
from oauth.login import OAuthLoginException
|
||||
from util.security.jwtutil import decode, InvalidTokenError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OIDC_WELLKNOWN = ".well-known/openid-configuration"
|
||||
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
|
||||
ALLOWED_ALGORITHMS = ['RS256']
|
||||
JWT_CLOCK_SKEW_SECONDS = 30
|
||||
|
||||
class DiscoveryFailureException(Exception):
|
||||
""" Exception raised when OIDC discovery fails. """
|
||||
pass
|
||||
|
||||
class PublicKeyLoadException(Exception):
|
||||
""" Exception raised if loading the OIDC public key fails. """
|
||||
pass
|
||||
|
||||
|
||||
class OIDCLoginService(OAuthService):
|
||||
""" Defines a generic service for all OpenID-connect compatible login services. """
|
||||
def __init__(self, config, key_name, client=None):
|
||||
super(OIDCLoginService, self).__init__(config, key_name)
|
||||
|
||||
self._id = key_name[0:key_name.find('_')].lower()
|
||||
self._http_client = client or config.get('HTTPCLIENT')
|
||||
self._mailing = config.get('FEATURE_MAILING', False)
|
||||
self._public_key_cache = _PublicKeyCache(self, 1, PUBLIC_KEY_CACHE_TTL)
|
||||
|
||||
def service_id(self):
|
||||
return self._id
|
||||
|
||||
def service_name(self):
|
||||
return self.config.get('SERVICE_NAME', self.service_id())
|
||||
|
||||
def get_icon(self):
|
||||
return self.config.get('SERVICE_ICON', 'fa-user-circle')
|
||||
|
||||
def get_login_scopes(self):
|
||||
default_scopes = ['openid']
|
||||
|
||||
if self.user_endpoint() is not None:
|
||||
default_scopes.append('profile')
|
||||
|
||||
if self._mailing:
|
||||
default_scopes.append('email')
|
||||
|
||||
supported_scopes = self._oidc_config().get('scopes_supported', default_scopes)
|
||||
login_scopes = self.config.get('LOGIN_SCOPES') or supported_scopes
|
||||
return list(set(login_scopes) & set(supported_scopes))
|
||||
|
||||
def authorize_endpoint(self):
|
||||
return self._get_endpoint('authorization_endpoint').with_param('response_type', 'code')
|
||||
|
||||
def token_endpoint(self):
|
||||
return self._get_endpoint('token_endpoint')
|
||||
|
||||
def user_endpoint(self):
|
||||
return self._get_endpoint('userinfo_endpoint')
|
||||
|
||||
def _get_endpoint(self, endpoint_key, **kwargs):
|
||||
""" Returns the OIDC endpoint with the given key found in the OIDC discovery
|
||||
document, with the given kwargs added as query parameters. Additionally,
|
||||
any defined parameters found in the OIDC configuration block are also
|
||||
added.
|
||||
"""
|
||||
endpoint = self._oidc_config().get(endpoint_key, '')
|
||||
if not endpoint:
|
||||
return None
|
||||
|
||||
(scheme, netloc, path, query, fragment) = urlparse.urlsplit(endpoint)
|
||||
|
||||
# Add the query parameters from the kwargs and the config.
|
||||
custom_parameters = self.config.get('OIDC_ENDPOINT_CUSTOM_PARAMS', {}).get(endpoint_key, {})
|
||||
|
||||
query_params = urlparse.parse_qs(query, keep_blank_values=True)
|
||||
query_params.update(kwargs)
|
||||
query_params.update(custom_parameters)
|
||||
return OAuthEndpoint(urlparse.urlunsplit((scheme, netloc, path, {}, fragment)), query_params)
|
||||
|
||||
def validate(self):
|
||||
return bool(self.get_login_scopes())
|
||||
|
||||
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
|
||||
# TODO: find a way to verify client secret too.
|
||||
check_auth_url = http_client.get(self.get_auth_url(url_scheme_and_hostname, '', '', []))
|
||||
if check_auth_url.status_code // 100 != 2:
|
||||
raise Exception('Got non-200 status code for authorization endpoint')
|
||||
|
||||
def requires_form_encoding(self):
|
||||
return True
|
||||
|
||||
def get_public_config(self):
|
||||
return {
|
||||
'CLIENT_ID': self.client_id(),
|
||||
'OIDC': True,
|
||||
}
|
||||
|
||||
def exchange_code_for_tokens(self, app_config, http_client, code, redirect_suffix):
|
||||
# Exchange the code for the access token and id_token
|
||||
try:
|
||||
json_data = self.exchange_code(app_config, http_client, code,
|
||||
redirect_suffix=redirect_suffix,
|
||||
form_encode=self.requires_form_encoding())
|
||||
except OAuthExchangeCodeException as oce:
|
||||
raise OAuthLoginException(str(oce))
|
||||
|
||||
# Make sure we received both.
|
||||
access_token = json_data.get('access_token', None)
|
||||
if access_token is None:
|
||||
logger.debug('Missing access_token in response: %s', json_data)
|
||||
raise OAuthLoginException('Missing `access_token` in OIDC response')
|
||||
|
||||
id_token = json_data.get('id_token', None)
|
||||
if id_token is None:
|
||||
logger.debug('Missing id_token in response: %s', json_data)
|
||||
raise OAuthLoginException('Missing `id_token` in OIDC response')
|
||||
|
||||
return id_token, access_token
|
||||
|
||||
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
|
||||
# Exchange the code for the access token and id_token
|
||||
id_token, access_token = self.exchange_code_for_tokens(app_config, http_client, code,
|
||||
redirect_suffix)
|
||||
|
||||
# Decode the id_token.
|
||||
try:
|
||||
decoded_id_token = self.decode_user_jwt(id_token)
|
||||
except InvalidTokenError as ite:
|
||||
logger.exception('Got invalid token error on OIDC decode: %s', ite)
|
||||
raise OAuthLoginException('Could not decode OIDC token')
|
||||
except PublicKeyLoadException as pke:
|
||||
logger.exception('Could not load public key during OIDC decode: %s', pke)
|
||||
raise OAuthLoginException('Could find public OIDC key')
|
||||
|
||||
# If there is a user endpoint, use it to retrieve the user's information. Otherwise, we use
|
||||
# the decoded ID token.
|
||||
if self.user_endpoint():
|
||||
# Retrieve the user information.
|
||||
try:
|
||||
user_info = self.get_user_info(http_client, access_token)
|
||||
except OAuthGetUserInfoException as oge:
|
||||
raise OAuthLoginException(str(oge))
|
||||
else:
|
||||
user_info = decoded_id_token
|
||||
|
||||
# Verify subs.
|
||||
if user_info['sub'] != decoded_id_token['sub']:
|
||||
logger.debug('Mismatch in `sub` returned by OIDC user info endpoint: %s vs %s',
|
||||
user_info['sub'], decoded_id_token['sub'])
|
||||
raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')
|
||||
|
||||
# Check if we have a verified email address.
|
||||
if self.config.get('VERIFIED_EMAIL_CLAIM_NAME'):
|
||||
email_address = user_info.get(self.config['VERIFIED_EMAIL_CLAIM_NAME'])
|
||||
else:
|
||||
email_address = user_info.get('email') if user_info.get('email_verified') else None
|
||||
|
||||
logger.debug('Found e-mail address `%s` for sub `%s`', email_address, user_info['sub'])
|
||||
if self._mailing:
|
||||
if email_address is None:
|
||||
raise OAuthLoginException('A verified email address is required to login with this service')
|
||||
|
||||
# Check for a preferred username.
|
||||
if self.config.get('PREFERRED_USERNAME_CLAIM_NAME'):
|
||||
lusername = user_info.get(self.config['PREFERRED_USERNAME_CLAIM_NAME'])
|
||||
else:
|
||||
lusername = user_info.get('preferred_username')
|
||||
if lusername is None:
|
||||
# Note: Active Directory provides `unique_name` and `upn`.
|
||||
# https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-id-and-access-tokens
|
||||
lusername = user_info.get('unique_name', user_info.get('upn'))
|
||||
|
||||
if lusername is None:
|
||||
lusername = user_info['sub']
|
||||
|
||||
if lusername.find('@') >= 0:
|
||||
lusername = lusername[0:lusername.find('@')]
|
||||
|
||||
return decoded_id_token['sub'], lusername, email_address
|
||||
|
||||
@property
|
||||
def _issuer(self):
|
||||
# Read the issuer from the OIDC config, falling back to the configured OIDC server.
|
||||
issuer = self._oidc_config().get('issuer', self.config['OIDC_SERVER'])
|
||||
|
||||
# If specified, use the overridden OIDC issuer.
|
||||
return self.config.get('OIDC_ISSUER', issuer)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _oidc_config(self):
|
||||
if self.config.get('OIDC_SERVER'):
|
||||
return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _load_oidc_config_via_discovery(self, is_debugging):
|
||||
""" Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True,
|
||||
non-secure connections are alllowed. Raises an DiscoveryFailureException on failure.
|
||||
"""
|
||||
oidc_server = self.config['OIDC_SERVER']
|
||||
if not oidc_server.startswith('https://') and not is_debugging:
|
||||
raise DiscoveryFailureException('OIDC server must be accessed over SSL')
|
||||
|
||||
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
|
||||
discovery = self._http_client.get(discovery_url, timeout=5, verify=is_debugging is False)
|
||||
if discovery.status_code // 100 != 2:
|
||||
logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
|
||||
raise DiscoveryFailureException("Could not load OIDC discovery information")
|
||||
|
||||
try:
|
||||
return json.loads(discovery.text)
|
||||
except ValueError:
|
||||
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
||||
raise DiscoveryFailureException("Could not parse OIDC discovery information")
|
||||
|
||||
def decode_user_jwt(self, token):
|
||||
""" Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
|
||||
exception on an invalid token or a PublicKeyLoadException if the public key could not be
|
||||
loaded for decoding.
|
||||
"""
|
||||
# Find the key to use.
|
||||
headers = jwt.get_unverified_header(token)
|
||||
kid = headers.get('kid', None)
|
||||
if kid is None:
|
||||
raise InvalidTokenError('Missing `kid` header')
|
||||
|
||||
logger.debug('Using key `%s`, attempting to decode token `%s` with aud `%s` and iss `%s`',
|
||||
kid, token, self.client_id(), self._issuer)
|
||||
try:
|
||||
return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS,
|
||||
audience=self.client_id(),
|
||||
issuer=self._issuer,
|
||||
leeway=JWT_CLOCK_SKEW_SECONDS,
|
||||
options=dict(require_nbf=False))
|
||||
except InvalidTokenError as ite:
|
||||
logger.warning('Could not decode token `%s` for OIDC: %s. Will attempt again after ' +
|
||||
'retrieving public keys.', token, ite)
|
||||
|
||||
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
|
||||
try:
|
||||
return decode(token, self._get_public_key(kid, force_refresh=True),
|
||||
algorithms=ALLOWED_ALGORITHMS,
|
||||
audience=self.client_id(),
|
||||
issuer=self._issuer,
|
||||
leeway=JWT_CLOCK_SKEW_SECONDS,
|
||||
options=dict(require_nbf=False))
|
||||
except InvalidTokenError as ite:
|
||||
logger.warning('Could not decode token `%s` for OIDC: %s. Attempted again after ' +
|
||||
'retrieving public keys.', token, ite)
|
||||
|
||||
# Decode again with verify=False, and log the decoded token to allow for easier debugging.
|
||||
nonverified = decode(token, self._get_public_key(kid, force_refresh=True),
|
||||
algorithms=ALLOWED_ALGORITHMS,
|
||||
audience=self.client_id(),
|
||||
issuer=self._issuer,
|
||||
leeway=JWT_CLOCK_SKEW_SECONDS,
|
||||
options=dict(require_nbf=False, verify=False))
|
||||
logger.debug('Got an error when trying to verify OIDC JWT: %s', nonverified)
|
||||
raise ite
|
||||
|
||||
def _get_public_key(self, kid, force_refresh=False):
|
||||
""" Retrieves the public key for this handler with the given kid. Raises a
|
||||
PublicKeyLoadException on failure. """
|
||||
|
||||
# If force_refresh is true, we expire all the items in the cache by setting the time to
|
||||
# the current time + the expiration TTL.
|
||||
if force_refresh:
|
||||
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
|
||||
|
||||
# Retrieve the public key from the cache. If the cache does not contain the public key,
|
||||
# it will internally call _load_public_key to retrieve it and then save it.
|
||||
return self._public_key_cache[kid]
|
||||
|
||||
|
||||
class _PublicKeyCache(TTLCache):
|
||||
def __init__(self, login_service, *args, **kwargs):
|
||||
super(_PublicKeyCache, self).__init__(*args, **kwargs)
|
||||
|
||||
self._login_service = login_service
|
||||
|
||||
def __missing__(self, kid):
|
||||
""" Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
|
||||
on failure.
|
||||
"""
|
||||
keys_url = self._login_service._oidc_config()['jwks_uri']
|
||||
|
||||
# Load the keys.
|
||||
try:
|
||||
keys = KEYS()
|
||||
keys.load_from_url(keys_url, verify=not self._login_service.config.get('DEBUGGING', False))
|
||||
except Exception as ex:
|
||||
logger.exception('Exception loading public key')
|
||||
raise PublicKeyLoadException(str(ex))
|
||||
|
||||
# Find the matching key.
|
||||
keys_found = keys.by_kid(kid)
|
||||
if len(keys_found) == 0:
|
||||
raise PublicKeyLoadException('Public key %s not found' % kid)
|
||||
|
||||
rsa_keys = [key for key in keys_found if key.kty == 'RSA']
|
||||
if len(rsa_keys) == 0:
|
||||
raise PublicKeyLoadException('No RSA form of public key %s not found' % kid)
|
||||
|
||||
matching_key = rsa_keys[0]
|
||||
matching_key.deserialize()
|
||||
|
||||
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
|
||||
# issues.
|
||||
final_key = load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())
|
||||
self[kid] = final_key
|
||||
return final_key
|
0
oauth/services/__init__.py
Normal file
0
oauth/services/__init__.py
Normal file
180
oauth/services/github.py
Normal file
180
oauth/services/github.py
Normal file
|
@ -0,0 +1,180 @@
|
|||
import logging
|
||||
|
||||
from oauth.base import OAuthEndpoint
|
||||
from oauth.login import OAuthLoginService, OAuthLoginException
|
||||
from util import slash_join
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class GithubOAuthService(OAuthLoginService):
|
||||
def __init__(self, config, key_name):
|
||||
super(GithubOAuthService, self).__init__(config, key_name)
|
||||
|
||||
def login_enabled(self, config):
|
||||
return config.get('FEATURE_GITHUB_LOGIN', False)
|
||||
|
||||
def service_id(self):
|
||||
return 'github'
|
||||
|
||||
def service_name(self):
|
||||
if self.is_enterprise():
|
||||
return 'GitHub Enterprise'
|
||||
|
||||
return 'GitHub'
|
||||
|
||||
def get_icon(self):
|
||||
return 'fa-github'
|
||||
|
||||
def get_login_scopes(self):
|
||||
if self.config.get('ORG_RESTRICT'):
|
||||
return ['user:email', 'read:org']
|
||||
|
||||
return ['user:email']
|
||||
|
||||
def allowed_organizations(self):
|
||||
if not self.config.get('ORG_RESTRICT', False):
|
||||
return None
|
||||
|
||||
allowed = self.config.get('ALLOWED_ORGANIZATIONS', None)
|
||||
if allowed is None:
|
||||
return None
|
||||
|
||||
return [org.lower() for org in allowed]
|
||||
|
||||
def get_public_url(self, suffix):
|
||||
return slash_join(self._endpoint(), suffix)
|
||||
|
||||
def _endpoint(self):
|
||||
return self.config.get('GITHUB_ENDPOINT', 'https://github.com')
|
||||
|
||||
def is_enterprise(self):
|
||||
return self._api_endpoint().find('.github.com') < 0
|
||||
|
||||
def authorize_endpoint(self):
|
||||
return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/authorize'))
|
||||
|
||||
def token_endpoint(self):
|
||||
return OAuthEndpoint(slash_join(self._endpoint(), '/login/oauth/access_token'))
|
||||
|
||||
def user_endpoint(self):
|
||||
return OAuthEndpoint(slash_join(self._api_endpoint(), 'user'))
|
||||
|
||||
def _api_endpoint(self):
|
||||
return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/'))
|
||||
|
||||
def api_endpoint(self):
|
||||
endpoint = self._api_endpoint()
|
||||
if endpoint.endswith('/'):
|
||||
return endpoint[0:-1]
|
||||
|
||||
return endpoint
|
||||
|
||||
def email_endpoint(self):
|
||||
return slash_join(self._api_endpoint(), 'user/emails')
|
||||
|
||||
def orgs_endpoint(self):
|
||||
return slash_join(self._api_endpoint(), 'user/orgs')
|
||||
|
||||
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
|
||||
# First: Verify that the github endpoint is actually Github by checking for the
|
||||
# X-GitHub-Request-Id here.
|
||||
api_endpoint = self._api_endpoint()
|
||||
result = http_client.get(api_endpoint, auth=(self.client_id(), self.client_secret()), timeout=5)
|
||||
if not 'X-GitHub-Request-Id' in result.headers:
|
||||
raise Exception('Endpoint is not a Github (Enterprise) installation')
|
||||
|
||||
# Next: Verify the client ID and secret.
|
||||
# Note: The following code is a hack until such time as Github officially adds an API endpoint
|
||||
# for verifying a {client_id, client_secret} pair. This workaround was given to us
|
||||
# *by a Github Engineer* (Jan 8, 2015).
|
||||
#
|
||||
# TODO: Replace with the real API call once added.
|
||||
#
|
||||
# Hitting the endpoint applications/{client_id}/tokens/foo will result in the following
|
||||
# behavior IF the client_id is given as the HTTP username and the client_secret as the HTTP
|
||||
# password:
|
||||
# - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error.
|
||||
# - If the pair is valid, then we get a 404 because the 'foo' token does not exists.
|
||||
validate_endpoint = slash_join(api_endpoint, 'applications/%s/tokens/foo' % self.client_id())
|
||||
result = http_client.get(validate_endpoint, auth=(self.client_id(), self.client_secret()),
|
||||
timeout=5)
|
||||
return result.status_code == 404
|
||||
|
||||
def validate_organization(self, organization_id, http_client):
|
||||
org_endpoint = slash_join(self._api_endpoint(), 'orgs/%s' % organization_id.lower())
|
||||
|
||||
result = http_client.get(org_endpoint,
|
||||
headers={'Accept': 'application/vnd.github.moondragon+json'},
|
||||
timeout=5)
|
||||
|
||||
return result.status_code == 200
|
||||
|
||||
|
||||
def get_public_config(self):
|
||||
return {
|
||||
'CLIENT_ID': self.client_id(),
|
||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url(),
|
||||
'GITHUB_ENDPOINT': self._endpoint(),
|
||||
'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
|
||||
}
|
||||
|
||||
def get_login_service_id(self, user_info):
|
||||
return user_info['id']
|
||||
|
||||
def get_login_service_username(self, user_info):
|
||||
return user_info['login']
|
||||
|
||||
def get_verified_user_email(self, app_config, http_client, token, user_info):
|
||||
v3_media_type = {
|
||||
'Accept': 'application/vnd.github.v3'
|
||||
}
|
||||
|
||||
token_param = {
|
||||
'access_token': token,
|
||||
}
|
||||
|
||||
# Find the e-mail address for the user: we will accept any email, but we prefer the primary
|
||||
get_email = http_client.get(self.email_endpoint(), params=token_param, headers=v3_media_type)
|
||||
if get_email.status_code // 100 != 2:
|
||||
raise OAuthLoginException('Got non-2XX status code for emails endpoint: %s' %
|
||||
get_email.status_code)
|
||||
|
||||
verified_emails = [email for email in get_email.json() if email['verified']]
|
||||
primary_emails = [email for email in get_email.json() if email['primary']]
|
||||
|
||||
# Special case: We don't care about whether an e-mail address is "verified" under GHE.
|
||||
if self.is_enterprise() and not verified_emails:
|
||||
verified_emails = primary_emails
|
||||
|
||||
allowed_emails = (primary_emails or verified_emails or [])
|
||||
return allowed_emails[0]['email'] if len(allowed_emails) > 0 else None
|
||||
|
||||
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
|
||||
# Retrieve the user's orgnizations (if organization filtering is turned on)
|
||||
if self.allowed_organizations() is None:
|
||||
return
|
||||
|
||||
moondragon_media_type = {
|
||||
'Accept': 'application/vnd.github.moondragon+json'
|
||||
}
|
||||
|
||||
token_param = {
|
||||
'access_token': token,
|
||||
}
|
||||
|
||||
get_orgs = http_client.get(self.orgs_endpoint(), params=token_param,
|
||||
headers=moondragon_media_type)
|
||||
|
||||
if get_orgs.status_code // 100 != 2:
|
||||
logger.debug('get_orgs response: %s', get_orgs.json())
|
||||
raise OAuthLoginException('Got non-2XX response for org lookup: %s' %
|
||||
get_orgs.status_code)
|
||||
|
||||
organizations = set([org.get('login').lower() for org in get_orgs.json()])
|
||||
matching_organizations = organizations & set(self.allowed_organizations())
|
||||
if not matching_organizations:
|
||||
logger.debug('Found organizations %s, but expected one of %s', organizations,
|
||||
self.allowed_organizations())
|
||||
err = """You are not a member of an allowed GitHub organization.
|
||||
Please contact your system administrator if you believe this is in error."""
|
||||
raise OAuthLoginException(err)
|
60
oauth/services/gitlab.py
Normal file
60
oauth/services/gitlab.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
from oauth.base import OAuthService, OAuthEndpoint
|
||||
from util import slash_join
|
||||
|
||||
class GitLabOAuthService(OAuthService):
|
||||
def __init__(self, config, key_name):
|
||||
super(GitLabOAuthService, self).__init__(config, key_name)
|
||||
|
||||
def service_id(self):
|
||||
return 'gitlab'
|
||||
|
||||
def service_name(self):
|
||||
return 'GitLab'
|
||||
|
||||
def _endpoint(self):
|
||||
return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')
|
||||
|
||||
def user_endpoint(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def api_endpoint(self):
|
||||
return self._endpoint()
|
||||
|
||||
def get_public_url(self, suffix):
|
||||
return slash_join(self._endpoint(), suffix)
|
||||
|
||||
def authorize_endpoint(self):
|
||||
return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/authorize'))
|
||||
|
||||
def token_endpoint(self):
|
||||
return OAuthEndpoint(slash_join(self._endpoint(), '/oauth/token'))
|
||||
|
||||
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
|
||||
# We validate the client ID and secret by hitting the OAuth token exchange endpoint with
|
||||
# the real client ID and secret, but a fake auth code to exchange. Gitlab's implementation will
|
||||
# return `invalid_client` as the `error` if the client ID or secret is invalid; otherwise, it
|
||||
# will return another error.
|
||||
url = self.token_endpoint().to_url()
|
||||
redirect_uri = self.get_redirect_uri(url_scheme_and_hostname, redirect_suffix='trigger')
|
||||
data = {
|
||||
'code': 'fakecode',
|
||||
'client_id': self.client_id(),
|
||||
'client_secret': self.client_secret(),
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': redirect_uri
|
||||
}
|
||||
|
||||
# We validate by checking the error code we receive from this call.
|
||||
result = http_client.post(url, data=data, timeout=5)
|
||||
value = result.json()
|
||||
if not value:
|
||||
return False
|
||||
|
||||
return value.get('error', '') != 'invalid_client'
|
||||
|
||||
def get_public_config(self):
|
||||
return {
|
||||
'CLIENT_ID': self.client_id(),
|
||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url(),
|
||||
'GITLAB_ENDPOINT': self._endpoint(),
|
||||
}
|
81
oauth/services/google.py
Normal file
81
oauth/services/google.py
Normal file
|
@ -0,0 +1,81 @@
|
|||
from oauth.base import OAuthEndpoint
|
||||
from oauth.login import OAuthLoginService
|
||||
|
||||
def _get_email_username(email_address):
|
||||
username = email_address
|
||||
at = username.find('@')
|
||||
if at > 0:
|
||||
username = username[0:at]
|
||||
|
||||
return username
|
||||
|
||||
class GoogleOAuthService(OAuthLoginService):
|
||||
def __init__(self, config, key_name):
|
||||
super(GoogleOAuthService, self).__init__(config, key_name)
|
||||
|
||||
def login_enabled(self, config):
|
||||
return config.get('FEATURE_GOOGLE_LOGIN', False)
|
||||
|
||||
def service_id(self):
|
||||
return 'google'
|
||||
|
||||
def service_name(self):
|
||||
return 'Google'
|
||||
|
||||
def get_icon(self):
|
||||
return 'fa-google'
|
||||
|
||||
def get_login_scopes(self):
|
||||
return ['openid', 'email']
|
||||
|
||||
def authorize_endpoint(self):
|
||||
return OAuthEndpoint('https://accounts.google.com/o/oauth2/auth',
|
||||
params=dict(response_type='code'))
|
||||
|
||||
def token_endpoint(self):
|
||||
return OAuthEndpoint('https://accounts.google.com/o/oauth2/token')
|
||||
|
||||
def user_endpoint(self):
|
||||
return OAuthEndpoint('https://www.googleapis.com/oauth2/v1/userinfo')
|
||||
|
||||
def requires_form_encoding(self):
|
||||
return True
|
||||
|
||||
def validate_client_id_and_secret(self, http_client, url_scheme_and_hostname):
|
||||
# To verify the Google client ID and secret, we hit the
|
||||
# https://www.googleapis.com/oauth2/v3/token endpoint with an invalid request. If the client
|
||||
# ID or secret are invalid, we get returned a 403 Unauthorized. Otherwise, we get returned
|
||||
# another response code.
|
||||
url = 'https://www.googleapis.com/oauth2/v3/token'
|
||||
data = {
|
||||
'code': 'fakecode',
|
||||
'client_id': self.client_id(),
|
||||
'client_secret': self.client_secret(),
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': 'http://example.com'
|
||||
}
|
||||
|
||||
result = http_client.post(url, data=data, timeout=5)
|
||||
return result.status_code != 401
|
||||
|
||||
def get_public_config(self):
|
||||
return {
|
||||
'CLIENT_ID': self.client_id(),
|
||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint().to_url()
|
||||
}
|
||||
|
||||
def get_login_service_id(self, user_info):
|
||||
return user_info['id']
|
||||
|
||||
def get_login_service_username(self, user_info):
|
||||
return _get_email_username(user_info['email'])
|
||||
|
||||
def get_verified_user_email(self, app_config, http_client, token, user_info):
|
||||
if not user_info.get('verified_email', False):
|
||||
return None
|
||||
|
||||
return user_info['email']
|
||||
|
||||
def service_verify_user_info_for_login(self, app_config, http_client, token, user_info):
|
||||
# Nothing to do.
|
||||
pass
|
38
oauth/services/test/test_github.py
Normal file
38
oauth/services/test/test_github.py
Normal file
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
|
||||
from oauth.services.github import GithubOAuthService
|
||||
|
||||
@pytest.mark.parametrize('trigger_config, domain, api_endpoint, is_enterprise', [
|
||||
({
|
||||
'CLIENT_ID': 'someclientid',
|
||||
'CLIENT_SECRET': 'someclientsecret',
|
||||
'API_ENDPOINT': 'https://api.github.com/v3',
|
||||
}, 'https://github.com', 'https://api.github.com/v3', False),
|
||||
({
|
||||
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
|
||||
'CLIENT_ID': 'someclientid',
|
||||
'CLIENT_SECRET': 'someclientsecret',
|
||||
}, 'https://github.somedomain.com', 'https://github.somedomain.com/api/v3', True),
|
||||
({
|
||||
'GITHUB_ENDPOINT': 'https://github.somedomain.com/',
|
||||
'API_ENDPOINT': 'http://somedomain.com/api/',
|
||||
'CLIENT_ID': 'someclientid',
|
||||
'CLIENT_SECRET': 'someclientsecret',
|
||||
}, 'https://github.somedomain.com', 'http://somedomain.com/api', True),
|
||||
])
|
||||
def test_basic_enterprise_config(trigger_config, domain, api_endpoint, is_enterprise):
|
||||
config = {
|
||||
'GITHUB_TRIGGER_CONFIG': trigger_config
|
||||
}
|
||||
|
||||
github_trigger = GithubOAuthService(config, 'GITHUB_TRIGGER_CONFIG')
|
||||
assert github_trigger.is_enterprise() == is_enterprise
|
||||
|
||||
assert github_trigger.authorize_endpoint().to_url() == '%s/login/oauth/authorize' % domain
|
||||
|
||||
assert github_trigger.token_endpoint().to_url() == '%s/login/oauth/access_token' % domain
|
||||
|
||||
assert github_trigger.api_endpoint() == api_endpoint
|
||||
assert github_trigger.user_endpoint().to_url() == '%s/user' % api_endpoint
|
||||
assert github_trigger.email_endpoint() == '%s/user/emails' % api_endpoint
|
||||
assert github_trigger.orgs_endpoint() == '%s/user/orgs' % api_endpoint
|
0
oauth/test/__init__.py
Normal file
0
oauth/test/__init__.py
Normal file
62
oauth/test/test_loginmanager.py
Normal file
62
oauth/test/test_loginmanager.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from oauth.loginmanager import OAuthLoginManager
|
||||
from oauth.services.github import GithubOAuthService
|
||||
from oauth.services.google import GoogleOAuthService
|
||||
from oauth.oidc import OIDCLoginService
|
||||
|
||||
def test_login_manager_github():
|
||||
config = {
|
||||
'FEATURE_GITHUB_LOGIN': True,
|
||||
'GITHUB_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], GithubOAuthService)
|
||||
|
||||
def test_github_disabled():
|
||||
config = {
|
||||
'GITHUB_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 0
|
||||
|
||||
def test_login_manager_google():
|
||||
config = {
|
||||
'FEATURE_GOOGLE_LOGIN': True,
|
||||
'GOOGLE_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], GoogleOAuthService)
|
||||
|
||||
def test_google_disabled():
|
||||
config = {
|
||||
'GOOGLE_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 0
|
||||
|
||||
def test_oidc():
|
||||
config = {
|
||||
'SOMECOOL_LOGIN_CONFIG': {},
|
||||
'HTTPCLIENT': None,
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], OIDCLoginService)
|
||||
|
||||
def test_multiple_oidc():
|
||||
config = {
|
||||
'SOMECOOL_LOGIN_CONFIG': {},
|
||||
'ANOTHER_LOGIN_CONFIG': {},
|
||||
'HTTPCLIENT': None,
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 2
|
||||
assert isinstance(loginmanager.services[0], OIDCLoginService)
|
||||
assert isinstance(loginmanager.services[1], OIDCLoginService)
|
342
oauth/test/test_oidc.py
Normal file
342
oauth/test/test_oidc.py
Normal file
|
@ -0,0 +1,342 @@
|
|||
# pylint: disable=redefined-outer-name, unused-argument, invalid-name, missing-docstring, too-many-arguments
|
||||
|
||||
import json
|
||||
import time
|
||||
import urlparse
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
from Crypto.PublicKey import RSA
|
||||
from jwkest.jwk import RSAKey
|
||||
|
||||
from oauth.oidc import OIDCLoginService, OAuthLoginException
|
||||
from util.config import URLSchemeAndHostname
|
||||
|
||||
|
||||
@pytest.fixture(scope='module') # Slow to generate, only do it once.
|
||||
def signing_key():
|
||||
private_key = RSA.generate(2048)
|
||||
jwk = RSAKey(key=private_key.publickey()).serialize()
|
||||
return {
|
||||
'id': 'somekey',
|
||||
'private_key': private_key.exportKey('PEM'),
|
||||
'jwk': jwk,
|
||||
}
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def http_client():
|
||||
sess = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=100,
|
||||
pool_maxsize=100)
|
||||
sess.mount('http://', adapter)
|
||||
sess.mount('https://', adapter)
|
||||
return sess
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def valid_code():
|
||||
return 'validcode'
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def mailing_feature(request):
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def email_verified(request):
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def userinfo_supported(request):
|
||||
return request.param
|
||||
|
||||
@pytest.fixture(params=["someusername", "foo@bar.com", None])
|
||||
def preferred_username(request):
|
||||
return request.param
|
||||
|
||||
@pytest.fixture()
|
||||
def app_config(http_client, mailing_feature):
|
||||
return {
|
||||
'PREFERRED_URL_SCHEME': 'http',
|
||||
'SERVER_HOSTNAME': 'localhost',
|
||||
'FEATURE_MAILING': mailing_feature,
|
||||
|
||||
'SOMEOIDC_LOGIN_CONFIG': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'SERVICE_NAME': 'Some Cool Service',
|
||||
'SERVICE_ICON': 'http://some/icon',
|
||||
'OIDC_SERVER': 'http://fakeoidc',
|
||||
'DEBUGGING': True,
|
||||
},
|
||||
|
||||
'ANOTHEROIDC_LOGIN_CONFIG': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'SERVICE_NAME': 'Some Other Service',
|
||||
'SERVICE_ICON': 'http://some/icon',
|
||||
'OIDC_SERVER': 'http://fakeoidc',
|
||||
'LOGIN_SCOPES': ['openid'],
|
||||
'DEBUGGING': True,
|
||||
},
|
||||
|
||||
'OIDCWITHPARAMS_LOGIN_CONFIG': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'SERVICE_NAME': 'Some Other Service',
|
||||
'SERVICE_ICON': 'http://some/icon',
|
||||
'OIDC_SERVER': 'http://fakeoidc',
|
||||
'DEBUGGING': True,
|
||||
'OIDC_ENDPOINT_CUSTOM_PARAMS': {
|
||||
'authorization_endpoint': {
|
||||
'some': 'param',
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
'HTTPCLIENT': http_client,
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def oidc_service(app_config):
|
||||
return OIDCLoginService(app_config, 'SOMEOIDC_LOGIN_CONFIG')
|
||||
|
||||
@pytest.fixture()
|
||||
def another_oidc_service(app_config):
|
||||
return OIDCLoginService(app_config, 'ANOTHEROIDC_LOGIN_CONFIG')
|
||||
|
||||
@pytest.fixture()
|
||||
def oidc_withparams_service(app_config):
|
||||
return OIDCLoginService(app_config, 'OIDCWITHPARAMS_LOGIN_CONFIG')
|
||||
|
||||
@pytest.fixture()
|
||||
def discovery_content(userinfo_supported):
|
||||
return {
|
||||
'scopes_supported': ['openid', 'profile', 'somescope'],
|
||||
'authorization_endpoint': 'http://fakeoidc/authorize',
|
||||
'token_endpoint': 'http://fakeoidc/token',
|
||||
'userinfo_endpoint': 'http://fakeoidc/userinfo' if userinfo_supported else None,
|
||||
'jwks_uri': 'http://fakeoidc/jwks',
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def userinfo_content(preferred_username, email_verified):
|
||||
return {
|
||||
'sub': 'cooluser',
|
||||
'preferred_username': preferred_username,
|
||||
'email': 'foo@example.com',
|
||||
'email_verified': email_verified,
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def id_token(oidc_service, signing_key, userinfo_content, app_config):
|
||||
token_data = {
|
||||
'iss': oidc_service.config['OIDC_SERVER'],
|
||||
'aud': oidc_service.client_id(),
|
||||
'nbf': int(time.time()),
|
||||
'iat': int(time.time()),
|
||||
'exp': int(time.time() + 600),
|
||||
'sub': 'cooluser',
|
||||
}
|
||||
|
||||
token_data.update(userinfo_content)
|
||||
|
||||
token_headers = {
|
||||
'kid': signing_key['id'],
|
||||
}
|
||||
|
||||
return jwt.encode(token_data, signing_key['private_key'], 'RS256', headers=token_headers)
|
||||
|
||||
@pytest.fixture()
|
||||
def discovery_handler(discovery_content):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'.+openid.+')
|
||||
def handler(_, __):
|
||||
return json.dumps(discovery_content)
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def authorize_handler(discovery_content):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/authorize')
|
||||
def handler(_, request):
|
||||
parsed = urlparse.urlparse(request.url)
|
||||
params = urlparse.parse_qs(parsed.query)
|
||||
return json.dumps({'authorized': True, 'scope': params['scope'][0], 'state': params['state'][0]})
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def token_handler(oidc_service, id_token, valid_code):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/token')
|
||||
def handler(_, request):
|
||||
params = urlparse.parse_qs(request.body)
|
||||
if params.get('redirect_uri')[0] != 'http://localhost/oauth2/someoidc/callback':
|
||||
return {'status_code': 400, 'content': 'Invalid redirect URI'}
|
||||
|
||||
if params.get('client_id')[0] != oidc_service.client_id():
|
||||
return {'status_code': 401, 'content': 'Invalid client id'}
|
||||
|
||||
if params.get('client_secret')[0] != oidc_service.client_secret():
|
||||
return {'status_code': 401, 'content': 'Invalid client secret'}
|
||||
|
||||
if params.get('code')[0] != valid_code:
|
||||
return {'status_code': 401, 'content': 'Invalid code'}
|
||||
|
||||
if params.get('grant_type')[0] != 'authorization_code':
|
||||
return {'status_code': 400, 'content': 'Invalid authorization type'}
|
||||
|
||||
content = {
|
||||
'access_token': 'sometoken',
|
||||
'id_token': id_token,
|
||||
}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def jwks_handler(signing_key):
|
||||
def jwk_with_kid(kid, jwk):
|
||||
jwk = jwk.copy()
|
||||
jwk.update({'kid': kid})
|
||||
return jwk
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
|
||||
def handler(_, __):
|
||||
content = {'keys': [jwk_with_kid(signing_key['id'], signing_key['jwk'])]}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def emptykeys_jwks_handler():
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
|
||||
def handler(_, __):
|
||||
content = {'keys': []}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture
|
||||
def userinfo_handler(oidc_service, userinfo_content):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
|
||||
def handler(_, req):
|
||||
if req.headers.get('Authorization') != 'Bearer sometoken':
|
||||
return {'status_code': 401, 'content': 'Missing expected header'}
|
||||
|
||||
return {'status_code': 200, 'content': json.dumps(userinfo_content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def invalidsub_userinfo_handler(oidc_service):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
|
||||
def handler(_, __):
|
||||
content = {
|
||||
'sub': 'invalidsub',
|
||||
}
|
||||
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def test_basic_config(oidc_service):
|
||||
assert oidc_service.service_id() == 'someoidc'
|
||||
assert oidc_service.service_name() == 'Some Cool Service'
|
||||
assert oidc_service.get_icon() == 'http://some/icon'
|
||||
|
||||
def test_discovery(oidc_service, http_client, discovery_content, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
auth = discovery_content['authorization_endpoint'] + '?response_type=code'
|
||||
assert oidc_service.authorize_endpoint().to_url() == auth
|
||||
assert oidc_service.token_endpoint().to_url() == discovery_content['token_endpoint']
|
||||
|
||||
if discovery_content['userinfo_endpoint'] is None:
|
||||
assert oidc_service.user_endpoint() is None
|
||||
else:
|
||||
assert oidc_service.user_endpoint().to_url() == discovery_content['userinfo_endpoint']
|
||||
|
||||
assert set(oidc_service.get_login_scopes()) == set(discovery_content['scopes_supported'])
|
||||
|
||||
def test_discovery_with_params(oidc_withparams_service, http_client, discovery_content, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
assert 'some=param' in oidc_withparams_service.authorize_endpoint().to_url()
|
||||
|
||||
def test_filtered_discovery(another_oidc_service, http_client, discovery_content, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
assert another_oidc_service.get_login_scopes() == ['openid']
|
||||
|
||||
def test_public_config(oidc_service, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
assert oidc_service.get_public_config()['OIDC']
|
||||
assert oidc_service.get_public_config()['CLIENT_ID'] == 'foo'
|
||||
|
||||
assert 'CLIENT_SECRET' not in oidc_service.get_public_config()
|
||||
assert 'bar' not in oidc_service.get_public_config().values()
|
||||
|
||||
def test_auth_url(oidc_service, discovery_handler, http_client, authorize_handler):
|
||||
config = {'PREFERRED_URL_SCHEME': 'https', 'SERVER_HOSTNAME': 'someserver'}
|
||||
|
||||
with HTTMock(discovery_handler, authorize_handler):
|
||||
url_scheme_and_hostname = URLSchemeAndHostname.from_app_config(config)
|
||||
auth_url = oidc_service.get_auth_url(url_scheme_and_hostname, '', 'some csrf token', ['one', 'two'])
|
||||
|
||||
# Hit the URL and ensure it works.
|
||||
result = http_client.get(auth_url).json()
|
||||
assert result['state'] == 'some csrf token'
|
||||
assert result['scope'] == 'one two'
|
||||
|
||||
def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler):
|
||||
with HTTMock(token_handler, discovery_handler):
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, 'testcode', '')
|
||||
|
||||
def test_exchange_code_invalidsub(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, invalidsub_userinfo_handler, jwks_handler,
|
||||
valid_code, userinfo_supported):
|
||||
# Skip when userinfo is not supported.
|
||||
if not userinfo_supported:
|
||||
return
|
||||
|
||||
with HTTMock(jwks_handler, token_handler, invalidsub_userinfo_handler, discovery_handler):
|
||||
# Should fail because the sub of the user info doesn't match that returned by the id_token.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
||||
|
||||
def test_exchange_code_missingkey(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, userinfo_handler, emptykeys_jwks_handler,
|
||||
valid_code):
|
||||
with HTTMock(emptykeys_jwks_handler, token_handler, userinfo_handler, discovery_handler):
|
||||
# Should fail because the key is missing.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
||||
|
||||
def test_exchange_code_validcode(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, userinfo_handler, jwks_handler, valid_code,
|
||||
preferred_username, mailing_feature, email_verified):
|
||||
with HTTMock(jwks_handler, token_handler, userinfo_handler, discovery_handler):
|
||||
if mailing_feature and not email_verified:
|
||||
# Should fail because there isn't a verified email address.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
||||
else:
|
||||
# Should succeed.
|
||||
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
|
||||
valid_code, '')
|
||||
|
||||
assert lid == 'cooluser'
|
||||
|
||||
if email_verified:
|
||||
assert lemail == 'foo@example.com'
|
||||
else:
|
||||
assert lemail is None
|
||||
|
||||
if preferred_username is not None:
|
||||
if preferred_username.find('@') >= 0:
|
||||
preferred_username = preferred_username[0:preferred_username.find('@')]
|
||||
|
||||
assert lusername == preferred_username
|
||||
else:
|
||||
assert lusername == lid
|
Reference in a new issue