2c35383724
This feature is subtle but very important: currently, when a user logs in via an "external" auth system (such as GitHub), they are either logged into an existing bound account or a new account is created for them in the database. While this normally works just fine, it hits a roadblock when the configured *internal* auth system is not the database but something like LDAP. In that case, *most* Enterprise customers will prefer that logging in via external auth (like OIDC) also *automatically* binds the newly created account to the backing *internal* auth account. For example, logging in via PingFederate OIDC (backed by LDAP) should also bind the new QE account to the associated LDAP account, via either username or email. This change allows this binding field to be specified, and thereafter performs the proper lookups and bindings.
166 lines
5.9 KiB
Python
166 lines
5.9 KiB
Python
import logging
|
|
import urllib
|
|
|
|
from abc import ABCMeta, abstractmethod
|
|
from six import add_metaclass
|
|
|
|
from util import get_app_url
|
|
|
|
# Module-level logger, named after this module; the OAuth code-exchange paths emit debug
# messages through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
class OAuthExchangeCodeException(Exception):
    """ Raised when exchanging an OAuth authorization code for a token does not succeed. """
|
|
|
|
class OAuthGetUserInfoException(Exception):
    """ Raised when the call that looks up the authenticated user's information fails. """
|
|
|
|
@add_metaclass(ABCMeta)
class OAuthService(object):
    """ A base class for defining an external service, exposed via OAuth. """

    def __init__(self, config, key_name):
        """ Binds the service to its section of the application config.

        `config` is the application config mapping and `key_name` is the key of this service's
        section within it. A missing or empty section becomes `{}` so that the accessors below
        can always call `.get` on `self.config`.
        """
        self.key_name = key_name
        self.config = config.get(key_name) or {}

    @abstractmethod
    def service_id(self):
        """ The internal ID for this service. Must match the URL portion for the service, e.g.
        `github`.
        """
        pass

    @abstractmethod
    def service_name(self):
        """ The user-readable name for the service, e.g. `GitHub`. """
        pass

    @abstractmethod
    def token_endpoint(self):
        """ The endpoint at which the OAuth code can be exchanged for a token. """
        pass

    @abstractmethod
    def user_endpoint(self):
        """ The endpoint at which user information can be looked up. """
        pass

    @abstractmethod
    def validate_client_id_and_secret(self, http_client, app_config):
        """ Performs validation of the client ID and secret, raising an exception on failure. """
        pass

    @abstractmethod
    def authorize_endpoint(self):
        """ Endpoint for authorization.

        `get_auth_url` appends the URL-encoded authorization query parameters to this value.
        """
        pass

    def requires_form_encoding(self):
        """ Returns True if form encoding is necessary for the exchange_code_for_token call. """
        return False

    def client_id(self):
        """ Returns the configured `CLIENT_ID` for this service, or None if it is not set. """
        return self.config.get('CLIENT_ID')

    def client_secret(self):
        """ Returns the configured `CLIENT_SECRET` for this service, or None if it is not set. """
        return self.config.get('CLIENT_SECRET')

    def login_binding_field(self):
        """ Returns the name of the field (`username` or `email`) used for auto binding an external
        login service account to an *internal* login service account. For example, if the external
        login service is GitHub and the internal login service is LDAP, a value of `email` here
        will cause login-with-GitHub to conduct a search (via email) in LDAP for a user, and auto
        bind the external and internal users together. May return None, in which case no binding
        is performed, and login with this external account will simply create a new account in
        the database.

        The value is read from this service's `LOGIN_BINDING_FIELD` config key.
        """
        return self.config.get('LOGIN_BINDING_FIELD', None)

    @staticmethod
    def _append_query(endpoint, query):
        """ Appends an already URL-encoded query string to `endpoint`.

        An endpoint that already ends in `?` or `&` is concatenated as-is (the historical
        behavior of `get_auth_url`). Otherwise the correct separator is inserted, so a bare
        endpoint no longer produces a mangled URL.
        """
        if endpoint.endswith(('?', '&')):
            return endpoint + query

        separator = '&' if '?' in endpoint else '?'
        return endpoint + separator + query

    def get_auth_url(self, app_config, redirect_suffix, csrf_token, scopes):
        """ Retrieves the authorization URL for this login service.

        The redirect URI comes from `get_redirect_uri`, the same helper `exchange_code` uses.
        This keeps the `redirect_uri` sent on the authorization request identical to the one
        sent on the token request, as OAuth 2 requires. `csrf_token` is passed as the `state`
        parameter and `scopes` are joined with spaces.
        """
        params = {
            'client_id': self.client_id(),
            'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix),
            'scope': ' '.join(scopes),
            'state': csrf_token,
        }

        return self._append_query(self.authorize_endpoint(), urllib.urlencode(params))

    def get_redirect_uri(self, app_config, redirect_suffix=''):
        """ Returns this service's OAuth callback URI:
        `<PREFERRED_URL_SCHEME>://<SERVER_HOSTNAME>/oauth2/<service_id>/callback<redirect_suffix>`.
        """
        return '%s://%s/oauth2/%s/callback%s' % (app_config['PREFERRED_URL_SCHEME'],
                                                 app_config['SERVER_HOSTNAME'],
                                                 self.service_id(),
                                                 redirect_suffix)

    def get_user_info(self, http_client, token):
        """ Looks up the user information for `token` at `user_endpoint()`.

        Returns the decoded JSON body.

        Raises OAuthGetUserInfoException if the call returns a non-2XX status, the body cannot
        be decoded as JSON, or the decoded body is None.
        """
        token_param = {
            'access_token': token,
            'alt': 'json',
        }

        got_user = http_client.get(self.user_endpoint(), params=token_param)
        if got_user.status_code // 100 != 2:
            raise OAuthGetUserInfoException('Non-2XX response code for user_info call: %s' %
                                            got_user.status_code)

        try:
            user_info = got_user.json()
        except ValueError:
            # Assuming a requests-style response, `.json()` raises ValueError on an undecodable
            # body. Report it through this module's exception type instead of leaking it.
            raise OAuthGetUserInfoException('Got non-JSON response for user_info call')

        if user_info is None:
            raise OAuthGetUserInfoException('Got empty response for user_info call')

        return user_info

    def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
                                redirect_suffix='', client_auth=False):
        """ Exchanges an OAuth access code for the associated OAuth token.

        Raises OAuthExchangeCodeException if the exchange fails or the response has no
        `access_token`.
        """
        json_data = self.exchange_code(app_config, http_client, code, form_encode, redirect_suffix,
                                       client_auth)

        access_token = json_data.get('access_token', None)
        if access_token is None:
            logger.debug('Got successful get_access_token response with missing token: %s', json_data)
            raise OAuthExchangeCodeException('Missing `access_token` in OAuth response')

        return access_token

    def exchange_code(self, app_config, http_client, code, form_encode=False, redirect_suffix='',
                      client_auth=False):
        """ Exchanges an OAuth access code for associated OAuth token and other data.

        With `client_auth`, the client ID and secret are sent as HTTP basic auth. Otherwise they
        are added to the payload. With `form_encode`, the payload is sent as the POST body.
        Otherwise it is sent as query parameters.

        Returns the decoded JSON response object (a dict).

        Raises OAuthExchangeCodeException in any of these cases:
        - the response status is non-2XX;
        - the body is not JSON, is empty, or is not a JSON object;
        - the response contains an `error` field.
        """
        payload = {
            'code': code,
            'grant_type': 'authorization_code',
            'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix),
        }

        headers = {
            'Accept': 'application/json',
        }

        auth = None
        if client_auth:
            auth = (self.client_id(), self.client_secret())
        else:
            # NOTE(review): when form_encode is False, this puts the client secret in the query
            # string, where proxies or server access logs may record it.
            payload['client_id'] = self.client_id()
            payload['client_secret'] = self.client_secret()

        token_url = self.token_endpoint()
        if form_encode:
            get_access_token = http_client.post(token_url, data=payload, headers=headers, auth=auth)
        else:
            get_access_token = http_client.post(token_url, params=payload, headers=headers,
                                                auth=auth)

        if get_access_token.status_code // 100 != 2:
            logger.debug('Got get_access_token response %s', get_access_token.text)
            raise OAuthExchangeCodeException('Got non-2XX response for code exchange: %s' %
                                             get_access_token.status_code)

        try:
            json_data = get_access_token.json()
        except ValueError:
            # Assuming a requests-style response, `.json()` raises on an undecodable body rather
            # than returning a falsy value, so this is where non-JSON responses actually land.
            # The body is deliberately not logged: a 2XX response may carry the token in
            # another encoding.
            raise OAuthExchangeCodeException('Got non-JSON response for code exchange')

        if not json_data:
            raise OAuthExchangeCodeException('Got non-JSON response for code exchange')

        # Callers (e.g. exchange_code_for_token) use dict methods on the result.
        if not isinstance(json_data, dict):
            raise OAuthExchangeCodeException('Got non-object JSON response for code exchange')

        if 'error' in json_data:
            raise OAuthExchangeCodeException(json_data.get('error_description', json_data['error']))

        return json_data
|