Add proper and tested OIDC support on the server
Note that this will still not work on the client side; the followup CL for the client side is right after this one.
This commit is contained in:
parent
19f7acf575
commit
fda203e4d7
15 changed files with 756 additions and 180 deletions
|
@ -368,7 +368,7 @@ def update_user_metadata(user, given_name=None, family_name=None, company=None):
|
|||
remove_user_prompt(user, UserPromptTypes.ENTER_COMPANY)
|
||||
|
||||
|
||||
def create_federated_user(username, email, service_name, service_ident,
|
||||
def create_federated_user(username, email, service_id, service_ident,
|
||||
set_password_notification, metadata={},
|
||||
email_required=True, prompts=tuple()):
|
||||
prompts = set(prompts)
|
||||
|
@ -378,7 +378,11 @@ def create_federated_user(username, email, service_name, service_ident,
|
|||
new_user.verified = True
|
||||
new_user.save()
|
||||
|
||||
service = LoginService.get(LoginService.name == service_name)
|
||||
try:
|
||||
service = LoginService.get(LoginService.name == service_id)
|
||||
except LoginService.DoesNotExist:
|
||||
service = LoginService.create(name=service_id)
|
||||
|
||||
FederatedLogin.create(user=new_user, service=service,
|
||||
service_ident=service_ident,
|
||||
metadata_json=json.dumps(metadata))
|
||||
|
@ -389,20 +393,20 @@ def create_federated_user(username, email, service_name, service_ident,
|
|||
return new_user
|
||||
|
||||
|
||||
def attach_federated_login(user, service_name, service_ident, metadata={}):
|
||||
service = LoginService.get(LoginService.name == service_name)
|
||||
def attach_federated_login(user, service_id, service_ident, metadata={}):
|
||||
service = LoginService.get(LoginService.name == service_id)
|
||||
FederatedLogin.create(user=user, service=service, service_ident=service_ident,
|
||||
metadata_json=json.dumps(metadata))
|
||||
return user
|
||||
|
||||
|
||||
def verify_federated_login(service_name, service_ident):
|
||||
def verify_federated_login(service_id, service_ident):
|
||||
try:
|
||||
found = (FederatedLogin
|
||||
.select(FederatedLogin, User)
|
||||
.join(LoginService)
|
||||
.switch(FederatedLogin).join(User)
|
||||
.where(FederatedLogin.service_ident == service_ident, LoginService.name == service_name)
|
||||
.where(FederatedLogin.service_ident == service_ident, LoginService.name == service_id)
|
||||
.get())
|
||||
return found.user
|
||||
except FederatedLogin.DoesNotExist:
|
||||
|
|
|
@ -197,7 +197,6 @@ def render_page_template(name, route_data=None, **kwargs):
|
|||
'title': login_service.service_name(),
|
||||
'config': login_service.get_public_config(),
|
||||
'icon': login_service.get_icon(),
|
||||
'scopes': login_service.get_login_scopes(),
|
||||
})
|
||||
|
||||
return login_config
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import logging
|
||||
import uuid
|
||||
|
||||
from flask import request, redirect, url_for, Blueprint
|
||||
from peewee import IntegrityError
|
||||
|
@ -50,6 +51,7 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met
|
|||
|
||||
# Try to create the user
|
||||
try:
|
||||
# Generate a valid username.
|
||||
new_username = None
|
||||
for valid in generate_valid_usernames(username):
|
||||
if model.user.get_user_or_org(valid):
|
||||
|
@ -58,6 +60,11 @@ def _conduct_oauth_login(service_id, service_name, user_id, username, email, met
|
|||
new_username = valid
|
||||
break
|
||||
|
||||
# Generate a valid email. If the email is None and the MAILING feature is turned
|
||||
# off, simply place in a fake email address.
|
||||
if email is None and not features.MAILING:
|
||||
email = '%s@fake.example.com' % (str(uuid.uuid4()))
|
||||
|
||||
prompts = model.user.get_default_user_prompts(features)
|
||||
to_login = model.user.create_federated_user(new_username, email, service_id,
|
||||
user_id, set_password_notification=True,
|
||||
|
@ -102,6 +109,7 @@ def _register_service(login_service):
|
|||
try:
|
||||
lid, lusername, lemail = login_service.exchange_code_for_login(app.config, client, code, '')
|
||||
except OAuthLoginException as ole:
|
||||
logger.exception('Got login exception')
|
||||
return _render_ologin_error(login_service.service_name(), ole.message)
|
||||
|
||||
# Conduct login.
|
||||
|
|
|
@ -22,7 +22,7 @@ EXTERNAL_JS = [
|
|||
]
|
||||
|
||||
EXTERNAL_CSS = [
|
||||
'netdna.bootstrapcdn.com/font-awesome/4.6.0/css/font-awesome.css',
|
||||
'netdna.bootstrapcdn.com/font-awesome/4.7.0/css/font-awesome.css',
|
||||
'netdna.bootstrapcdn.com/bootstrap/3.3.2/css/bootstrap.min.css',
|
||||
'fonts.googleapis.com/css?family=Source+Sans+Pro:300,400,700',
|
||||
's3.amazonaws.com/cdn.core-os.net/icons/core-icons.css',
|
||||
|
|
|
@ -51,7 +51,7 @@ class OAuthService(object):
|
|||
def get_redirect_uri(self, app_config, redirect_suffix=''):
|
||||
return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
|
||||
app_config['SERVER_HOSTNAME'],
|
||||
self.service_name().lower(),
|
||||
self.service_id(),
|
||||
redirect_suffix)
|
||||
|
||||
def get_user_info(self, http_client, token):
|
||||
|
@ -74,7 +74,7 @@ class OAuthService(object):
|
|||
def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
|
||||
redirect_suffix='', client_auth=False):
|
||||
""" Exchanges an OAuth access code for the associated OAuth token. """
|
||||
json_data = self._exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
|
||||
json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
|
||||
client_auth)
|
||||
|
||||
access_token = json_data.get('access_token', None)
|
||||
|
@ -84,8 +84,9 @@ class OAuthService(object):
|
|||
|
||||
return access_token
|
||||
|
||||
def _exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
|
||||
def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
|
||||
client_auth=False):
|
||||
""" Exchanges an OAuth access code for associated OAuth token and other data. """
|
||||
payload = {
|
||||
'code': code,
|
||||
'grant_type': 'authorization_code',
|
||||
|
|
|
@ -14,6 +14,10 @@ class OAuthLoginService(OAuthService):
|
|||
""" A base class for defining an OAuth-compliant service that can be used for, amongst other
|
||||
things, login and authentication. """
|
||||
|
||||
def login_enabled(self):
|
||||
""" Returns true if the login service is enabled. """
|
||||
raise NotImplementedError
|
||||
|
||||
def get_login_service_id(self, user_info):
|
||||
""" Returns the internal ID for the given user under this login service. """
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,7 +1,11 @@
|
|||
import features
|
||||
|
||||
from oauth.services.github import GithubOAuthService
|
||||
from oauth.services.google import GoogleOAuthService
|
||||
from oauth.oidc import OIDCLoginService
|
||||
|
||||
CUSTOM_LOGIN_SERVICES = {
|
||||
'GITHUB_LOGIN_CONFIG': GithubOAuthService,
|
||||
'GOOGLE_LOGIN_CONFIG': GoogleOAuthService,
|
||||
}
|
||||
|
||||
class OAuthLoginManager(object):
|
||||
""" Helper class which manages all registered OAuth login services. """
|
||||
|
@ -9,11 +13,12 @@ class OAuthLoginManager(object):
|
|||
self.services = []
|
||||
|
||||
# Register the endpoints for each of the OAuth login services.
|
||||
# TODO(jschorr): make this dynamic.
|
||||
if config.get('GITHUB_LOGIN_CONFIG') is not None and features.GITHUB_LOGIN:
|
||||
github_service = GithubOAuthService(config, 'GITHUB_LOGIN_CONFIG')
|
||||
self.services.append(github_service)
|
||||
|
||||
if config.get('GOOGLE_LOGIN_CONFIG') is not None and features.GOOGLE_LOGIN:
|
||||
google_service = GoogleOAuthService(config, 'GOOGLE_LOGIN_CONFIG')
|
||||
self.services.append(google_service)
|
||||
for key in config.keys():
|
||||
# All keys which end in _LOGIN_CONFIG setup a login service.
|
||||
if key.endswith('_LOGIN_CONFIG'):
|
||||
if key in CUSTOM_LOGIN_SERVICES:
|
||||
custom_service = CUSTOM_LOGIN_SERVICES[key](config, key)
|
||||
if custom_service.login_enabled(config):
|
||||
self.services.append(custom_service)
|
||||
else:
|
||||
self.services.append(OIDCLoginService(config, key))
|
||||
|
|
242
oauth/oidc.py
242
oauth/oidc.py
|
@ -3,108 +3,244 @@ import json
|
|||
import logging
|
||||
import urlparse
|
||||
|
||||
import jwt
|
||||
|
||||
from cachetools import lru_cache
|
||||
from cachetools.ttl import TTLCache
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.serialization import load_der_public_key
|
||||
from jwkest.jwk import KEYS
|
||||
|
||||
from util.oauth.base import OAuthService
|
||||
from oauth.base import OAuthService, OAuthExchangeCodeException, OAuthGetUserInfoException
|
||||
from oauth.login import OAuthLoginException
|
||||
from util.security.jwtutil import decode, InvalidTokenError
|
||||
from util import get_app_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_user_jwt(token, oidc_provider):
|
||||
try:
|
||||
return decode(token, oidc_provider.get_public_key(), algorithms=['RS256'],
|
||||
audience=oidc_provider.client_id(),
|
||||
issuer=oidc_provider.issuer)
|
||||
except InvalidTokenError:
|
||||
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
|
||||
return decode(token, oidc_provider.get_public_key(force_refresh=True), algorithms=['RS256'],
|
||||
audience=oidc_provider.client_id(),
|
||||
issuer=oidc_provider.issuer)
|
||||
|
||||
OIDC_WELLKNOWN = ".well-known/openid-configuration"
|
||||
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour
|
||||
ALLOWED_ALGORITHMS = ['RS256']
|
||||
JWT_CLOCK_SKEW_SECONDS = 30
|
||||
|
||||
class OIDCConfig(OAuthService):
|
||||
class DiscoveryFailureException(Exception):
|
||||
""" Exception raised when OIDC discovery fails. """
|
||||
pass
|
||||
|
||||
|
||||
class PublicKeyLoadException(Exception):
|
||||
""" Exception raised if loading the OIDC public key fails. """
|
||||
pass
|
||||
|
||||
|
||||
class OIDCLoginService(OAuthService):
|
||||
""" Defines a generic service for all OpenID-connect compatible login services. """
|
||||
def __init__(self, config, key_name):
|
||||
super(OIDCConfig, self).__init__(config, key_name)
|
||||
super(OIDCLoginService, self).__init__(config, key_name)
|
||||
|
||||
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
|
||||
self._config = config
|
||||
self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._load_public_key)
|
||||
self._id = key_name[0:key_name.find('_')].lower()
|
||||
self._http_client = config['HTTPCLIENT']
|
||||
self._mailing = config.get('FEATURE_MAILING', False)
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _oidc_config(self):
|
||||
if self.config.get('OIDC_SERVER'):
|
||||
return self._load_via_discovery(self._config.get('DEBUGGING', False))
|
||||
else:
|
||||
return {}
|
||||
def service_id(self):
|
||||
return self._id
|
||||
|
||||
def _load_via_discovery(self, is_debugging):
|
||||
oidc_server = self.config['OIDC_SERVER']
|
||||
if not oidc_server.startswith('https://') and not is_debugging:
|
||||
raise Exception('OIDC server must be accessed over SSL')
|
||||
def service_name(self):
|
||||
return self.config.get('SERVICE_NAME', self.service_id())
|
||||
|
||||
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
|
||||
discovery = self._http_client.get(discovery_url, timeout=5)
|
||||
def get_icon(self):
|
||||
return self.config.get('SERVICE_ICON', 'fa-user-circle')
|
||||
|
||||
if discovery.status_code / 100 != 2:
|
||||
raise Exception("Could not load OIDC discovery information")
|
||||
def get_login_scopes(self):
|
||||
default_scopes = ['openid']
|
||||
|
||||
try:
|
||||
return json.loads(discovery.text)
|
||||
except ValueError:
|
||||
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
||||
raise Exception("Could not parse OIDC discovery information")
|
||||
if self.user_endpoint() is not None:
|
||||
default_scopes.append('profile')
|
||||
|
||||
if self._mailing:
|
||||
default_scopes.append('email')
|
||||
|
||||
return self._oidc_config().get('scopes_supported', default_scopes)
|
||||
|
||||
def authorize_endpoint(self):
|
||||
return self._oidc_config().get('authorization_endpoint', '') + '?'
|
||||
return self._oidc_config().get('authorization_endpoint', '') + '?response_type=code&'
|
||||
|
||||
def token_endpoint(self):
|
||||
return self._oidc_config().get('token_endpoint')
|
||||
|
||||
def user_endpoint(self):
|
||||
return None
|
||||
return self._oidc_config().get('userinfo_endpoint')
|
||||
|
||||
def validate_client_id_and_secret(self, http_client, app_config):
|
||||
pass
|
||||
# TODO: find a way to verify client secret too.
|
||||
redirect_url = '%s/oauth2/%s/callback' % (get_app_url(app_config), self.service_id())
|
||||
scopes_string = ' '.join(self.get_login_scopes())
|
||||
authorize_url = '%sclient_id=%s&redirect_uri=%s&scope=%s' % (self.authorize_endpoint(),
|
||||
self.client_id(),
|
||||
redirect_url,
|
||||
scopes_string)
|
||||
|
||||
check_auth_url = http_client.get(authorize_url)
|
||||
if check_auth_url.status_code // 100 != 2:
|
||||
raise Exception('Got non-200 status code for authorization endpoint')
|
||||
|
||||
def requires_form_encoding(self):
|
||||
return True
|
||||
|
||||
def get_public_config(self):
|
||||
return {
|
||||
'CLIENT_ID': self.client_id(),
|
||||
'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
|
||||
'OIDC': True,
|
||||
}
|
||||
|
||||
def exchange_code_for_login(self, app_config, http_client, code, redirect_suffix):
|
||||
# Exchange the code for the access token and id_token
|
||||
try:
|
||||
json_data = self.exchange_code(app_config, http_client, code,
|
||||
redirect_suffix=redirect_suffix,
|
||||
form_encode=self.requires_form_encoding())
|
||||
except OAuthExchangeCodeException as oce:
|
||||
raise OAuthLoginException(oce.message)
|
||||
|
||||
# Make sure we received both.
|
||||
access_token = json_data.get('access_token', None)
|
||||
if access_token is None:
|
||||
logger.debug('Missing access_token in response: %s', json_data)
|
||||
raise OAuthLoginException('Missing `access_token` in OIDC response')
|
||||
|
||||
id_token = json_data.get('id_token', None)
|
||||
if id_token is None:
|
||||
logger.debug('Missing id_token in response: %s', json_data)
|
||||
raise OAuthLoginException('Missing `id_token` in OIDC response')
|
||||
|
||||
# Decode the id_token.
|
||||
try:
|
||||
decoded_id_token = self._decode_user_jwt(id_token)
|
||||
except InvalidTokenError as ite:
|
||||
logger.exception('Got invalid token error on OIDC decode: %s', ite.message)
|
||||
raise OAuthLoginException('Could not decode OIDC token')
|
||||
except PublicKeyLoadException as pke:
|
||||
logger.exception('Could not load public key during OIDC decode: %s', pke.message)
|
||||
raise OAuthLoginException('Could find public OIDC key')
|
||||
|
||||
# Retrieve the user information.
|
||||
try:
|
||||
user_info = self.get_user_info(http_client, access_token)
|
||||
except OAuthGetUserInfoException as oge:
|
||||
raise OAuthLoginException(oge.message)
|
||||
|
||||
# Verify subs.
|
||||
if user_info['sub'] != decoded_id_token['sub']:
|
||||
raise OAuthLoginException('Mismatch in `sub` returned by OIDC user info endpoint')
|
||||
|
||||
# Check if we have a verified email address.
|
||||
email_address = user_info.get('email') if user_info.get('email_verified') else None
|
||||
if self._mailing:
|
||||
if email_address is None:
|
||||
raise OAuthLoginException('A verified email address is required to login with this service')
|
||||
|
||||
# Check for a preferred username.
|
||||
lusername = user_info.get('preferred_username') or user_info.get('sub')
|
||||
return decoded_id_token['sub'], lusername, email_address
|
||||
|
||||
@property
|
||||
def issuer(self):
|
||||
def _issuer(self):
|
||||
return self.config.get('OIDC_ISSUER', self.config['OIDC_SERVER'])
|
||||
|
||||
def get_public_key(self, force_refresh=False):
|
||||
""" Retrieves the public key for this handler. """
|
||||
@lru_cache(maxsize=1)
|
||||
def _oidc_config(self):
|
||||
if self.config.get('OIDC_SERVER'):
|
||||
return self._load_oidc_config_via_discovery(self.config.get('DEBUGGING', False))
|
||||
else:
|
||||
return {}
|
||||
|
||||
def _load_oidc_config_via_discovery(self, is_debugging):
|
||||
""" Attempts to load the OIDC config via the OIDC discovery mechanism. If is_debugging is True,
|
||||
non-secure connections are alllowed. Raises an DiscoveryFailureException on failure.
|
||||
"""
|
||||
oidc_server = self.config['OIDC_SERVER']
|
||||
if not oidc_server.startswith('https://') and not is_debugging:
|
||||
raise DiscoveryFailureException('OIDC server must be accessed over SSL')
|
||||
|
||||
discovery_url = urlparse.urljoin(oidc_server, OIDC_WELLKNOWN)
|
||||
discovery = self._http_client.get(discovery_url, timeout=5, verify=not is_debugging)
|
||||
if discovery.status_code // 100 != 2:
|
||||
logger.debug('Got %s response for OIDC discovery: %s', discovery.status_code, discovery.text)
|
||||
raise DiscoveryFailureException("Could not load OIDC discovery information")
|
||||
|
||||
try:
|
||||
return json.loads(discovery.text)
|
||||
except ValueError:
|
||||
logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
|
||||
raise DiscoveryFailureException("Could not parse OIDC discovery information")
|
||||
|
||||
def _decode_user_jwt(self, token):
|
||||
""" Decodes the given JWT under the given provider and returns it. Raises an InvalidTokenError
|
||||
exception on an invalid token or a PublicKeyLoadException if the public key could not be
|
||||
loaded for decoding.
|
||||
"""
|
||||
# Find the key to use.
|
||||
headers = jwt.get_unverified_header(token)
|
||||
kid = headers.get('kid', None)
|
||||
if kid is None:
|
||||
raise InvalidTokenError('Missing `kid` header')
|
||||
|
||||
try:
|
||||
return decode(token, self._get_public_key(kid), algorithms=ALLOWED_ALGORITHMS,
|
||||
audience=self.client_id(),
|
||||
issuer=self._issuer,
|
||||
leeway=JWT_CLOCK_SKEW_SECONDS,
|
||||
options=dict(require_nbf=False))
|
||||
except InvalidTokenError:
|
||||
# Public key may have expired. Try to retrieve an updated public key and use it to decode.
|
||||
return decode(token, self._get_public_key(kid, force_refresh=True),
|
||||
algorithms=ALLOWED_ALGORITHMS,
|
||||
audience=self.client_id(),
|
||||
issuer=self._issuer,
|
||||
leeway=JWT_CLOCK_SKEW_SECONDS,
|
||||
options=dict(require_nbf=False))
|
||||
|
||||
def _get_public_key(self, kid, force_refresh=False):
|
||||
""" Retrieves the public key for this handler with the given kid. Raises a
|
||||
PublicKeyLoadException on failure. """
|
||||
|
||||
# If force_refresh is true, we expire all the items in the cache by setting the time to
|
||||
# the current time + the expiration TTL.
|
||||
if force_refresh:
|
||||
self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)
|
||||
|
||||
# Retrieve the public key from the cache. If the cache does not contain the public key,
|
||||
# it will internally call _get_public_key to retrieve it and then save it. The None is
|
||||
# a random key chose to be stored in the cache, and could be anything.
|
||||
return self._public_key_cache[None]
|
||||
# it will internally call _load_public_key to retrieve it and then save it.
|
||||
return self._public_key_cache[kid]
|
||||
|
||||
def _get_public_key(self, _):
|
||||
""" Retrieves the public key for this handler. """
|
||||
def _load_public_key(self, kid):
|
||||
""" Loads the public key for this handler from the OIDC service. Raises PublicKeyLoadException
|
||||
on failure.
|
||||
"""
|
||||
keys_url = self._oidc_config()['jwks_uri']
|
||||
|
||||
# Load the keys.
|
||||
try:
|
||||
keys = KEYS()
|
||||
keys.load_from_url(keys_url)
|
||||
keys.load_from_url(keys_url, verify=not self.config.get('DEBUGGING', False))
|
||||
except Exception as ex:
|
||||
logger.exception('Exception loading public key')
|
||||
raise PublicKeyLoadException(ex.message)
|
||||
|
||||
if not list(keys):
|
||||
raise Exception('No keys provided by OIDC provider')
|
||||
# Find the matching key.
|
||||
keys_found = keys.by_kid(kid)
|
||||
if len(keys_found) == 0:
|
||||
raise PublicKeyLoadException('Public key %s not found' % kid)
|
||||
|
||||
rsa_key = list(keys)[0]
|
||||
rsa_key.deserialize()
|
||||
rsa_keys = [key for key in keys_found if key.kty == 'RSA']
|
||||
if len(rsa_keys) == 0:
|
||||
raise PublicKeyLoadException('No RSA form of public key %s not found' % kid)
|
||||
|
||||
matching_key = rsa_keys[0]
|
||||
matching_key.deserialize()
|
||||
|
||||
# Reload the key so that we can give a key *instance* to PyJWT to work around its weird parsing
|
||||
# issues.
|
||||
return load_der_public_key(rsa_key.key.exportKey('DER'), backend=default_backend())
|
||||
return load_der_public_key(matching_key.key.exportKey('DER'), backend=default_backend())
|
||||
|
|
|
@ -9,6 +9,9 @@ class GithubOAuthService(OAuthLoginService):
|
|||
def __init__(self, config, key_name):
|
||||
super(GithubOAuthService, self).__init__(config, key_name)
|
||||
|
||||
def login_enabled(self, config):
|
||||
return config.get('FEATURE_GITHUB_LOGIN', False)
|
||||
|
||||
def service_id(self):
|
||||
return 'github'
|
||||
|
||||
|
|
|
@ -12,6 +12,9 @@ class GoogleOAuthService(OAuthLoginService):
|
|||
def __init__(self, config, key_name):
|
||||
super(GoogleOAuthService, self).__init__(config, key_name)
|
||||
|
||||
def login_enabled(self, config):
|
||||
return config.get('FEATURE_GOOGLE_LOGIN', False)
|
||||
|
||||
def service_id(self):
|
||||
return 'google'
|
||||
|
||||
|
|
62
oauth/test/test_loginmanager.py
Normal file
62
oauth/test/test_loginmanager.py
Normal file
|
@ -0,0 +1,62 @@
|
|||
from oauth.loginmanager import OAuthLoginManager
|
||||
from oauth.services.github import GithubOAuthService
|
||||
from oauth.services.google import GoogleOAuthService
|
||||
from oauth.oidc import OIDCLoginService
|
||||
|
||||
def test_login_manager_github():
|
||||
config = {
|
||||
'FEATURE_GITHUB_LOGIN': True,
|
||||
'GITHUB_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], GithubOAuthService)
|
||||
|
||||
def test_github_disabled():
|
||||
config = {
|
||||
'GITHUB_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 0
|
||||
|
||||
def test_login_manager_google():
|
||||
config = {
|
||||
'FEATURE_GOOGLE_LOGIN': True,
|
||||
'GOOGLE_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], GoogleOAuthService)
|
||||
|
||||
def test_google_disabled():
|
||||
config = {
|
||||
'GOOGLE_LOGIN_CONFIG': {},
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 0
|
||||
|
||||
def test_oidc():
|
||||
config = {
|
||||
'SOMECOOL_LOGIN_CONFIG': {},
|
||||
'HTTPCLIENT': None,
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 1
|
||||
assert isinstance(loginmanager.services[0], OIDCLoginService)
|
||||
|
||||
def test_multiple_oidc():
|
||||
config = {
|
||||
'SOMECOOL_LOGIN_CONFIG': {},
|
||||
'ANOTHER_LOGIN_CONFIG': {},
|
||||
'HTTPCLIENT': None,
|
||||
}
|
||||
|
||||
loginmanager = OAuthLoginManager(config)
|
||||
assert len(loginmanager.services) == 2
|
||||
assert isinstance(loginmanager.services[0], OIDCLoginService)
|
||||
assert isinstance(loginmanager.services[1], OIDCLoginService)
|
273
oauth/test/test_oidc.py
Normal file
273
oauth/test/test_oidc.py
Normal file
|
@ -0,0 +1,273 @@
|
|||
# pylint: disable=redefined-outer-name, unused-argument, C0103, C0111, too-many-arguments
|
||||
|
||||
import json
|
||||
import time
|
||||
import urlparse
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from httmock import urlmatch, HTTMock
|
||||
from Crypto.PublicKey import RSA
|
||||
from jwkest.jwk import RSAKey
|
||||
|
||||
from oauth.oidc import OIDCLoginService, OAuthLoginException
|
||||
|
||||
@pytest.fixture()
|
||||
def http_client():
|
||||
sess = requests.Session()
|
||||
adapter = requests.adapters.HTTPAdapter(pool_connections=100,
|
||||
pool_maxsize=100)
|
||||
sess.mount('http://', adapter)
|
||||
sess.mount('https://', adapter)
|
||||
return sess
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def app_config(http_client, request):
|
||||
return {
|
||||
'PREFERRED_URL_SCHEME': 'http',
|
||||
'SERVER_HOSTNAME': 'localhost',
|
||||
'FEATURE_MAILING': request.param,
|
||||
|
||||
'SOMEOIDC_TEST_SERVICE': {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'SERVICE_NAME': 'Some Cool Service',
|
||||
'SERVICE_ICON': 'http://some/icon',
|
||||
'OIDC_SERVER': 'http://fakeoidc',
|
||||
'DEBUGGING': True,
|
||||
},
|
||||
|
||||
'HTTPCLIENT': http_client,
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def oidc_service(app_config):
|
||||
return OIDCLoginService(app_config, 'SOMEOIDC_TEST_SERVICE')
|
||||
|
||||
@pytest.fixture()
|
||||
def discovery_content():
|
||||
return {
|
||||
'scopes_supported': ['profile'],
|
||||
'authorization_endpoint': 'http://fakeoidc/authorize',
|
||||
'token_endpoint': 'http://fakeoidc/token',
|
||||
'userinfo_endpoint': 'http://fakeoidc/userinfo',
|
||||
'jwks_uri': 'http://fakeoidc/jwks',
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def discovery_handler(discovery_content):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'.+openid.+')
|
||||
def handler(_, __):
|
||||
return json.dumps(discovery_content)
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture(scope="module") # Slow to generate, only do it once.
|
||||
def signing_key():
|
||||
private_key = RSA.generate(2048)
|
||||
jwk = RSAKey(key=private_key.publickey()).serialize()
|
||||
return {
|
||||
'id': 'somekey',
|
||||
'private_key': private_key.exportKey('PEM'),
|
||||
'jwk': jwk,
|
||||
}
|
||||
|
||||
@pytest.fixture()
|
||||
def id_token(oidc_service, signing_key, app_config):
|
||||
token_data = {
|
||||
'iss': oidc_service.config['OIDC_SERVER'],
|
||||
'aud': oidc_service.client_id(),
|
||||
'nbf': int(time.time()),
|
||||
'iat': int(time.time()),
|
||||
'exp': int(time.time() + 600),
|
||||
'sub': 'cooluser',
|
||||
}
|
||||
|
||||
token_headers = {
|
||||
'kid': signing_key['id'],
|
||||
}
|
||||
|
||||
return jwt.encode(token_data, signing_key['private_key'], 'RS256', headers=token_headers)
|
||||
|
||||
@pytest.fixture()
|
||||
def valid_code():
|
||||
return 'validcode'
|
||||
|
||||
@pytest.fixture()
|
||||
def token_handler(oidc_service, id_token, valid_code):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/token')
|
||||
def handler(_, request):
|
||||
params = urlparse.parse_qs(request.body)
|
||||
if params.get('redirect_uri')[0] != 'http://localhost/oauth2/someoidc/callback':
|
||||
return {'status_code': 400, 'content': 'Invalid redirect URI'}
|
||||
|
||||
if params.get('client_id')[0] != oidc_service.client_id():
|
||||
return {'status_code': 401, 'content': 'Invalid client id'}
|
||||
|
||||
if params.get('client_secret')[0] != oidc_service.client_secret():
|
||||
return {'status_code': 401, 'content': 'Invalid client secret'}
|
||||
|
||||
if params.get('code')[0] != valid_code:
|
||||
return {'status_code': 401, 'content': 'Invalid code'}
|
||||
|
||||
if params.get('grant_type')[0] != 'authorization_code':
|
||||
return {'status_code': 400, 'content': 'Invalid authorization type'}
|
||||
|
||||
content = {
|
||||
'access_token': 'sometoken',
|
||||
'id_token': id_token,
|
||||
}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def jwks_handler(signing_key):
|
||||
def jwk_with_kid(kid, jwk):
|
||||
jwk = jwk.copy()
|
||||
jwk.update({'kid': kid})
|
||||
return jwk
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
|
||||
def handler(_, __):
|
||||
content = {'keys': [jwk_with_kid(signing_key['id'], signing_key['jwk'])]}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def emptykeys_jwks_handler():
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/jwks')
|
||||
def handler(_, __):
|
||||
content = {'keys': []}
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture(params=["someusername", None])
|
||||
def preferred_username(request):
|
||||
return request.param
|
||||
|
||||
@pytest.fixture
|
||||
def userinfo_handler(oidc_service, preferred_username):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
|
||||
def handler(_, __):
|
||||
content = {
|
||||
'sub': 'cooluser',
|
||||
'preferred_username':preferred_username,
|
||||
'email': 'foo@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def invalidsub_userinfo_handler(oidc_service):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
|
||||
def handler(_, __):
|
||||
content = {
|
||||
'sub': 'invalidsub',
|
||||
'preferred_username': 'someusername',
|
||||
'email': 'foo@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
@pytest.fixture()
|
||||
def missingemail_userinfo_handler(oidc_service, preferred_username):
|
||||
@urlmatch(netloc=r'fakeoidc', path=r'/userinfo')
|
||||
def handler(_, __):
|
||||
content = {
|
||||
'sub': 'cooluser',
|
||||
'preferred_username': preferred_username,
|
||||
}
|
||||
|
||||
return {'status_code': 200, 'content': json.dumps(content)}
|
||||
|
||||
return handler
|
||||
|
||||
def test_basic_config(oidc_service):
|
||||
assert oidc_service.service_id() == 'someoidc'
|
||||
assert oidc_service.service_name() == 'Some Cool Service'
|
||||
assert oidc_service.get_icon() == 'http://some/icon'
|
||||
|
||||
def test_discovery(oidc_service, http_client, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
assert oidc_service.authorize_endpoint() == 'http://fakeoidc/authorize?response_type=code&'
|
||||
assert oidc_service.token_endpoint() == 'http://fakeoidc/token'
|
||||
assert oidc_service.user_endpoint() == 'http://fakeoidc/userinfo'
|
||||
assert oidc_service.get_login_scopes() == ['profile']
|
||||
|
||||
def test_public_config(oidc_service, discovery_handler):
|
||||
with HTTMock(discovery_handler):
|
||||
assert oidc_service.get_public_config()['OIDC']
|
||||
assert oidc_service.get_public_config()['CLIENT_ID'] == 'foo'
|
||||
|
||||
assert 'CLIENT_SECRET' not in oidc_service.get_public_config()
|
||||
assert 'bar' not in oidc_service.get_public_config().values()
|
||||
|
||||
def test_exchange_code_invalidcode(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler):
|
||||
with HTTMock(token_handler, discovery_handler):
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, 'testcode', '')
|
||||
|
||||
def test_exchange_code_validcode(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, userinfo_handler, jwks_handler, valid_code,
|
||||
preferred_username):
|
||||
with HTTMock(jwks_handler, token_handler, userinfo_handler, discovery_handler):
|
||||
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
|
||||
valid_code, '')
|
||||
|
||||
assert lid == 'cooluser'
|
||||
assert lemail == 'foo@example.com'
|
||||
|
||||
if preferred_username is not None:
|
||||
assert lusername == preferred_username
|
||||
else:
|
||||
assert lusername == lid
|
||||
|
||||
def test_exchange_code_missingemail(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, missingemail_userinfo_handler, jwks_handler,
|
||||
valid_code, preferred_username):
|
||||
with HTTMock(jwks_handler, token_handler, missingemail_userinfo_handler, discovery_handler):
|
||||
if app_config['FEATURE_MAILING']:
|
||||
# Should fail because there is no valid email address.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
||||
else:
|
||||
# Should succeed because, while there is no valid email address, it isn't necessary with
|
||||
# mailing disabled.
|
||||
lid, lusername, lemail = oidc_service.exchange_code_for_login(app_config, http_client,
|
||||
valid_code, '')
|
||||
|
||||
assert lid == 'cooluser'
|
||||
assert lemail is None
|
||||
|
||||
if preferred_username is not None:
|
||||
assert lusername == preferred_username
|
||||
else:
|
||||
assert lusername == lid
|
||||
|
||||
def test_exchange_code_invalidsub(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, invalidsub_userinfo_handler, jwks_handler,
|
||||
valid_code):
|
||||
with HTTMock(jwks_handler, token_handler, invalidsub_userinfo_handler, discovery_handler):
|
||||
# Should fail because the sub of the user info doesn't match that returned by the id_token.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
||||
|
||||
def test_exchange_code_missingkey(oidc_service, discovery_handler, app_config, http_client,
|
||||
token_handler, userinfo_handler, emptykeys_jwks_handler,
|
||||
valid_code):
|
||||
with HTTMock(emptykeys_jwks_handler, token_handler, userinfo_handler, discovery_handler):
|
||||
# Should fail because the key is missing.
|
||||
with pytest.raises(OAuthLoginException):
|
||||
oidc_service.exchange_code_for_login(app_config, http_client, valid_code, '')
|
|
@ -8,7 +8,6 @@ import base64
|
|||
from urllib import urlencode
|
||||
from urlparse import urlparse, urlunparse, parse_qs
|
||||
from datetime import datetime, timedelta
|
||||
from httmock import urlmatch, HTTMock
|
||||
|
||||
import jwt
|
||||
|
||||
|
@ -16,7 +15,7 @@ from Crypto.PublicKey import RSA
|
|||
from flask import url_for
|
||||
from jwkest.jwk import RSAKey
|
||||
|
||||
from app import app, oauth_login
|
||||
from app import app
|
||||
from data import model
|
||||
from data.database import ServiceKeyApprovalType
|
||||
from endpoints import keyserver
|
||||
|
@ -25,16 +24,9 @@ from endpoints.api.user import Signin
|
|||
from endpoints.keyserver import jwk_with_kid
|
||||
from endpoints.csrf import OAUTH_CSRF_TOKEN_NAME
|
||||
from endpoints.web import web as web_bp
|
||||
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
|
||||
from initdb import setup_database_for_testing, finished_database_for_testing
|
||||
from test.helpers import assert_action_logged
|
||||
|
||||
try:
|
||||
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2')
|
||||
except ValueError:
|
||||
# This blueprint was already registered
|
||||
pass
|
||||
|
||||
try:
|
||||
app.register_blueprint(web_bp, url_prefix='')
|
||||
except ValueError:
|
||||
|
@ -129,96 +121,6 @@ class EndpointTestCase(unittest.TestCase):
|
|||
self.assertEquals(rv.status_code, 200)
|
||||
|
||||
|
||||
class OAuthLoginTestCase(EndpointTestCase):
|
||||
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
|
||||
new_username):
|
||||
# Test callback.
|
||||
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
|
||||
|
||||
# Delete the created user.
|
||||
model.user.delete_user(created, [])
|
||||
|
||||
# Test attach.
|
||||
self.login('devtable', 'password')
|
||||
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
|
||||
|
||||
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
|
||||
# No CSRF.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
|
||||
|
||||
# Invalid CSRF.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
|
||||
|
||||
# Valid CSRF, invalid code.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||
code='invalidcode', expected_code=400)
|
||||
|
||||
# Valid CSRF, valid code.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||
code='somecode', expected_code=302)
|
||||
|
||||
# Ensure the user was added/modified.
|
||||
found_user = model.user.get_user(username)
|
||||
self.assertIsNotNone(found_user)
|
||||
|
||||
federated_login = model.user.lookup_federated_login(found_user, service_name)
|
||||
self.assertIsNotNone(federated_login)
|
||||
self.assertEquals(federated_login.service_ident, service_ident)
|
||||
return found_user
|
||||
|
||||
def test_google_oauth(self):
|
||||
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
|
||||
def account_handler(_, request):
|
||||
if request.body.find("code=somecode") > 0:
|
||||
content = {'access_token': 'someaccesstoken'}
|
||||
return py_json.dumps(content)
|
||||
else:
|
||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||
|
||||
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
|
||||
def user_handler(_, __):
|
||||
content = {
|
||||
'id': 'someid',
|
||||
'email': 'someemail@example.com',
|
||||
'verified_email': True,
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
with HTTMock(account_handler, user_handler):
|
||||
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
|
||||
'someid', 'someemail')
|
||||
|
||||
def test_github_oauth(self):
|
||||
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
|
||||
def account_handler(url, _):
|
||||
if url.query.find("code=somecode") > 0:
|
||||
content = {'access_token': 'someaccesstoken'}
|
||||
return py_json.dumps(content)
|
||||
else:
|
||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||
|
||||
@urlmatch(netloc=r'github.com', path='/api/v3/user')
|
||||
def user_handler(_, __):
|
||||
content = {
|
||||
'id': 'someid',
|
||||
'login': 'someusername'
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
|
||||
def email_handler(_, __):
|
||||
content = [{
|
||||
'email': 'someemail@example.com',
|
||||
'verified': True,
|
||||
'primary': True,
|
||||
}]
|
||||
return py_json.dumps(content)
|
||||
|
||||
with HTTMock(account_handler, email_handler, user_handler):
|
||||
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
|
||||
'someid', 'someusername')
|
||||
|
||||
|
||||
class WebEndpointTestCase(EndpointTestCase):
|
||||
def test_index(self):
|
||||
self.getResponse('web.index')
|
||||
|
|
175
test/test_oauth_login.py
Normal file
175
test/test_oauth_login.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import json as py_json
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import jwt
|
||||
|
||||
from Crypto.PublicKey import RSA
|
||||
from httmock import urlmatch, HTTMock
|
||||
from jwkest.jwk import RSAKey
|
||||
|
||||
from app import app
|
||||
from data import model
|
||||
from endpoints.oauthlogin import oauthlogin as oauthlogin_bp
|
||||
from test.test_endpoints import EndpointTestCase
|
||||
|
||||
try:
|
||||
app.register_blueprint(oauthlogin_bp, url_prefix='/oauth2')
|
||||
except ValueError:
|
||||
# This blueprint was already registered
|
||||
pass
|
||||
|
||||
class OAuthLoginTestCase(EndpointTestCase):
|
||||
def invoke_oauth_tests(self, callback_endpoint, attach_endpoint, service_name, service_ident,
|
||||
new_username):
|
||||
# Test callback.
|
||||
created = self.invoke_oauth_test(callback_endpoint, service_name, service_ident, new_username)
|
||||
|
||||
# Delete the created user.
|
||||
model.user.delete_user(created, [])
|
||||
|
||||
# Test attach.
|
||||
self.login('devtable', 'password')
|
||||
self.invoke_oauth_test(attach_endpoint, service_name, service_ident, 'devtable')
|
||||
|
||||
def invoke_oauth_test(self, endpoint_name, service_name, service_ident, username):
|
||||
# No CSRF.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, expected_code=403)
|
||||
|
||||
# Invalid CSRF.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='somestate', expected_code=403)
|
||||
|
||||
# Valid CSRF, invalid code.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||
code='invalidcode', expected_code=400)
|
||||
|
||||
# Valid CSRF, valid code.
|
||||
self.getResponse('oauthlogin.' + endpoint_name, state='someoauthtoken',
|
||||
code='somecode', expected_code=302)
|
||||
|
||||
# Ensure the user was added/modified.
|
||||
found_user = model.user.get_user(username)
|
||||
self.assertIsNotNone(found_user)
|
||||
|
||||
federated_login = model.user.lookup_federated_login(found_user, service_name)
|
||||
self.assertIsNotNone(federated_login)
|
||||
self.assertEquals(federated_login.service_ident, service_ident)
|
||||
return found_user
|
||||
|
||||
def test_google_oauth(self):
|
||||
@urlmatch(netloc=r'accounts.google.com', path='/o/oauth2/token')
|
||||
def account_handler(_, request):
|
||||
if request.body.find("code=somecode") > 0:
|
||||
content = {'access_token': 'someaccesstoken'}
|
||||
return py_json.dumps(content)
|
||||
else:
|
||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||
|
||||
@urlmatch(netloc=r'www.googleapis.com', path='/oauth2/v1/userinfo')
|
||||
def user_handler(_, __):
|
||||
content = {
|
||||
'id': 'someid',
|
||||
'email': 'someemail@example.com',
|
||||
'verified_email': True,
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
with HTTMock(account_handler, user_handler):
|
||||
self.invoke_oauth_tests('google_oauth_callback', 'google_oauth_attach', 'google',
|
||||
'someid', 'someemail')
|
||||
|
||||
def test_github_oauth(self):
|
||||
@urlmatch(netloc=r'github.com', path='/login/oauth/access_token')
|
||||
def account_handler(url, _):
|
||||
if url.query.find("code=somecode") > 0:
|
||||
content = {'access_token': 'someaccesstoken'}
|
||||
return py_json.dumps(content)
|
||||
else:
|
||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||
|
||||
@urlmatch(netloc=r'github.com', path='/api/v3/user')
|
||||
def user_handler(_, __):
|
||||
content = {
|
||||
'id': 'someid',
|
||||
'login': 'someusername'
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
@urlmatch(netloc=r'github.com', path='/api/v3/user/emails')
|
||||
def email_handler(_, __):
|
||||
content = [{
|
||||
'email': 'someemail@example.com',
|
||||
'verified': True,
|
||||
'primary': True,
|
||||
}]
|
||||
return py_json.dumps(content)
|
||||
|
||||
with HTTMock(account_handler, email_handler, user_handler):
|
||||
self.invoke_oauth_tests('github_oauth_callback', 'github_oauth_attach', 'github',
|
||||
'someid', 'someusername')
|
||||
|
||||
def test_oidc_auth(self):
|
||||
private_key = RSA.generate(2048)
|
||||
generatedjwk = RSAKey(key=private_key.publickey()).serialize()
|
||||
kid = 'somekey'
|
||||
private_pem = private_key.exportKey('PEM')
|
||||
|
||||
token_data = {
|
||||
'iss': app.config['TESTOIDC_LOGIN_CONFIG']['OIDC_SERVER'],
|
||||
'aud': app.config['TESTOIDC_LOGIN_CONFIG']['CLIENT_ID'],
|
||||
'nbf': int(time.time()),
|
||||
'iat': int(time.time()),
|
||||
'exp': int(time.time() + 600),
|
||||
'sub': 'cooluser',
|
||||
}
|
||||
|
||||
token_headers = {
|
||||
'kid': kid,
|
||||
}
|
||||
|
||||
id_token = jwt.encode(token_data, private_pem, 'RS256', headers=token_headers)
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path='/token')
|
||||
def token_handler(_, request):
|
||||
if request.body.find("code=somecode") >= 0:
|
||||
content = {'access_token': 'someaccesstoken', 'id_token': id_token}
|
||||
return py_json.dumps(content)
|
||||
else:
|
||||
return {'status_code': 400, 'content': '{"message": "Invalid code"}'}
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path='/user')
|
||||
def user_handler(_, __):
|
||||
content = {
|
||||
'sub': 'cooluser',
|
||||
'preferred_username': 'someusername',
|
||||
'email': 'someemail@example.com',
|
||||
'email_verified': True,
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path='/jwks')
|
||||
def jwks_handler(_, __):
|
||||
jwk = generatedjwk.copy()
|
||||
jwk.update({'kid': kid})
|
||||
|
||||
content = {'keys': [jwk]}
|
||||
return py_json.dumps(content)
|
||||
|
||||
@urlmatch(netloc=r'fakeoidc', path='.+openid.+')
|
||||
def discovery_handler(_, __):
|
||||
content = {
|
||||
'scopes_supported': ['profile'],
|
||||
'authorization_endpoint': 'http://fakeoidc/authorize',
|
||||
'token_endpoint': 'http://fakeoidc/token',
|
||||
'userinfo_endpoint': 'http://fakeoidc/userinfo',
|
||||
'jwks_uri': 'http://fakeoidc/jwks',
|
||||
}
|
||||
return py_json.dumps(content)
|
||||
|
||||
with HTTMock(discovery_handler, jwks_handler, token_handler, user_handler):
|
||||
self.invoke_oauth_tests('testoidc_oauth_callback', 'testoidc_oauth_attach', 'testoidc',
|
||||
'cooluser', 'someusername')
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -81,11 +81,12 @@ class TestConfig(DefaultConfig):
|
|||
|
||||
FEATURE_GITHUB_LOGIN = True
|
||||
FEATURE_GOOGLE_LOGIN = True
|
||||
FEATURE_DEX_LOGIN = True
|
||||
|
||||
DEX_LOGIN_CONFIG = {
|
||||
'CLIENT_ID': 'someclientid',
|
||||
'OIDC_SERVER': 'https://oidcserver/',
|
||||
TESTOIDC_LOGIN_CONFIG = {
|
||||
'CLIENT_ID': 'foo',
|
||||
'CLIENT_SECRET': 'bar',
|
||||
'OIDC_SERVER': 'http://fakeoidc',
|
||||
'DEBUGGING': True,
|
||||
}
|
||||
|
||||
RECAPTCHA_SITE_KEY = 'somekey'
|
||||
|
|
Reference in a new issue