This repository has been archived on 2020-03-24. You can view files and clone it, but cannot push or open issues or pull requests.
quay/oauth/base.py
Joseph Schorr 2c35383724 Allow OAuth and OIDC login engines to bind to fields in internal auth
This feature is subtle but very important: Currently, when a user logs in via an "external" auth system (such as GitHub), they are either logged into an existing bound account or a new account is created for them in the database. While this normally works just fine, it hits a roadblock when the configured *internal* auth system is not the database, but instead something like LDAP. In that case, *most* Enterprise customers will prefer that logging in via external auth (like OIDC) also *automatically* binds the newly created account to the backing *internal* auth account. For example, logging in via PingFederate OIDC (backed by LDAP) should also bind the new QE account to the associated LDAP account, via either username or email. This change allows this binding field to be specified, and thereafter performs the proper lookups and bindings.
2017-02-16 16:27:53 -05:00

166 lines
5.9 KiB
Python

import logging
import urllib
from abc import ABCMeta, abstractmethod
from six import add_metaclass
from util import get_app_url
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
class OAuthExchangeCodeException(Exception):
  """ Raised when exchanging an OAuth authorization code for a token does not succeed. """
class OAuthGetUserInfoException(Exception):
  """ Raised when the request for the external user's information does not succeed. """
@add_metaclass(ABCMeta)
class OAuthService(object):
  """ A base class for defining an external service, exposed via OAuth.

      Subclasses supply the service-specific pieces (IDs, endpoints, credential validation); this
      class implements the shared flow: building the authorization URL, exchanging an
      authorization code for a token, and fetching user information.
  """

  def __init__(self, config, key_name):
    """ Reads this service's settings from `config[key_name]`.

        A missing or falsey section is replaced with an empty dict, so the config accessors below
        return None instead of raising.
    """
    self.key_name = key_name
    self.config = config.get(key_name) or {}

  @abstractmethod
  def service_id(self):
    """ The internal ID for this service. Must match the URL portion for the service, e.g. `github`.
    """
    pass

  @abstractmethod
  def service_name(self):
    """ The user-readable name for the service, e.g. `GitHub`. """
    pass

  @abstractmethod
  def token_endpoint(self):
    """ The endpoint at which the OAuth code can be exchanged for a token. """
    pass

  @abstractmethod
  def user_endpoint(self):
    """ The endpoint at which user information can be looked up. """
    pass

  @abstractmethod
  def validate_client_id_and_secret(self, http_client, app_config):
    """ Performs validation of the client ID and secret, raising an exception on failure. """
    pass

  @abstractmethod
  def authorize_endpoint(self):
    """ Endpoint for authorization.

        `get_auth_url` appends the urlencoded query string to this value verbatim, so the returned
        value must already end with the query separator (e.g. `https://example.com/authorize?`).
    """
    pass

  def requires_form_encoding(self):
    """ Returns True if form encoding is necessary for the exchange_code_for_token call. """
    return False

  def client_id(self):
    """ Returns the configured `CLIENT_ID`, or None if it is not set. """
    return self.config.get('CLIENT_ID')

  def client_secret(self):
    """ Returns the configured `CLIENT_SECRET`, or None if it is not set. """
    return self.config.get('CLIENT_SECRET')

  def login_binding_field(self):
    """ Returns the name of the field (`username` or `email`) used for auto-binding an external
        login service account to an *internal* login service account. For example, if the external
        login service is GitHub and the internal login service is LDAP, a value of `email` here
        will cause login-with-GitHub to conduct a search (via email) in LDAP for a user, and
        auto-bind the external and internal users together. May return None (the
        `LOGIN_BINDING_FIELD` config key is unset), in which case no binding is performed, and
        login with this external account will simply create a new account in the database.
    """
    return self.config.get('LOGIN_BINDING_FIELD', None)

  def get_auth_url(self, app_config, redirect_suffix, csrf_token, scopes):
    """ Retrieves the authorization URL for this login service.

        `scopes` is an iterable of scope strings, joined with single spaces. `csrf_token` is sent
        as the OAuth `state` parameter.

        The redirect URI comes from `get_redirect_uri`, the same helper `exchange_code` uses, so
        for the same `app_config` and `redirect_suffix` both requests carry an identical
        `redirect_uri`, as the OAuth 2.0 code exchange requires.
    """
    params = {
      'client_id': self.client_id(),
      'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix),
      'scope': ' '.join(scopes),
      'state': csrf_token,
    }

    # No separator is inserted here: authorize_endpoint() is expected to end with `?`.
    return '%s%s' % (self.authorize_endpoint(), urllib.urlencode(params))

  def get_redirect_uri(self, app_config, redirect_suffix=''):
    """ Returns the OAuth callback URL for this service:
        `<PREFERRED_URL_SCHEME>://<SERVER_HOSTNAME>/oauth2/<service_id>/callback<redirect_suffix>`.
    """
    return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
                                             app_config['SERVER_HOSTNAME'],
                                             self.service_id(),
                                             redirect_suffix)

  def get_user_info(self, http_client, token):
    """ Fetches the user's information from `user_endpoint()` using the given access token.

        The token is sent as the `access_token` query parameter, along with `alt=json`.

        Returns the decoded JSON body.
        Raises OAuthGetUserInfoException on a non-2XX response, on a body that is not valid
        JSON, or on a JSON `null` body.
    """
    token_param = {
      'access_token': token,
      'alt': 'json',
    }

    got_user = http_client.get(self.user_endpoint(), params=token_param)
    if got_user.status_code // 100 != 2:
      raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' %
                                      got_user.status_code)

    try:
      user_info = got_user.json()
    except ValueError:
      # Assumes a requests-style client, whose `.json()` raises ValueError on an unparsable body;
      # surface that as the exception type callers of this method already handle.
      raise OAuthGetUserInfoException('Got non-JSON response for user_info call')

    if user_info is None:
      raise OAuthGetUserInfoException('Got empty response for user_info call')

    return user_info

  def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
                              redirect_suffix='', client_auth=False):
    """ Exchanges an OAuth access code for the associated OAuth token.

        Arguments are passed straight through to `exchange_code`.

        Returns the `access_token` value from the exchange response.
        Raises OAuthExchangeCodeException if the exchange fails or the response has no
        `access_token`.
    """
    json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
                                   client_auth)

    access_token = json_data.get('access_token', None)
    if access_token is None:
      logger.debug('Got successful get_access_token response with missing token: %s', json_data)
      raise OAuthExchangeCodeException('Missing `access_token` in OAuth response')

    return access_token

  def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
                    client_auth=False):
    """ Exchanges an OAuth access code for associated OAuth token and other data.

        Args:
          form_encode: if True, the payload is POSTed as a form body (`data=`); otherwise it is
            sent as URL query parameters (`params=`).
          redirect_suffix: appended to the callback URL built by `get_redirect_uri`.
          client_auth: if True, the client ID and secret are passed as the `auth` tuple (HTTP
            Basic auth for a requests-style client) instead of being placed in the payload.

        Returns the decoded JSON object (a non-empty dict) from the token endpoint.
        Raises OAuthExchangeCodeException on a non-2XX response; on a body that is not a
        non-empty JSON object; or when the body contains an `error` field.
    """
    payload = {
      'code': code,
      'grant_type': 'authorization_code',
      'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
    }

    headers = {
      'Accept': 'application/json'
    }

    auth = None
    if client_auth:
      auth = (self.client_id(), self.client_secret())
    else:
      payload['client_id'] = self.client_id()
      payload['client_secret'] = self.client_secret()

    token_url = self.token_endpoint()
    if form_encode:
      get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
    else:
      # NOTE(review): in this branch the payload (including client_secret when client_auth is
      # False) travels in the URL query string, where it may be recorded in access logs.
      get_access_token = http_client.post(token_url, params=payload, headers=headers, auth=auth)

    if get_access_token.status_code // 100 != 2:
      logger.debug('Got get_access_token response %s', get_access_token.text)
      raise OAuthExchangeCodeException('Got non-2XX response for code exchange: %s' %
                                       get_access_token.status_code)

    try:
      json_data = get_access_token.json()
    except ValueError:
      # Assumes a requests-style client, whose `.json()` raises ValueError on an unparsable body;
      # treat that like any other unusable response below.
      json_data = None

    # Only a non-empty JSON object can carry a token; callers rely on dict `.get` access.
    if not json_data or not isinstance(json_data, dict):
      raise OAuthExchangeCodeException('Got non-JSON response for code exchange')

    if 'error' in json_data:
      raise OAuthExchangeCodeException(json_data.get('error_description', json_data['error']))

    return json_data