258 lines
8.9 KiB
Python
258 lines
8.9 KiB
Python
import logging
|
|
|
|
from datetime import datetime, timedelta
|
|
from oauth2lib.provider import AuthorizationProvider
|
|
from oauth2lib import utils
|
|
|
|
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
|
|
random_string_generator)
|
|
from auth import scopes
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DatabaseAuthorizationProvider(AuthorizationProvider):
    """oauth2lib ``AuthorizationProvider`` backed by the OAuth models in ``data.database``.

    Client applications, authorization codes and access tokens are read from and written to the
    ``OAuthApplication``, ``OAuthAuthorizationCode`` and ``OAuthAccessToken`` tables.
    Subclasses must implement :meth:`get_authorized_user`.
    """

    def get_authorized_user(self):
        """Return the user on whose behalf the current request is being authorized.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')

    def validate_client_id(self, client_id):
        """Return True when an application with ``client_id`` exists."""
        return self.get_application_for_client_id(client_id) is not None

    def get_application_for_client_id(self, client_id):
        """Return the ``OAuthApplication`` with ``client_id``, or ``None`` if there is none."""
        try:
            return OAuthApplication.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return None

    def validate_client_secret(self, client_id, client_secret):
        """Return True when ``client_id`` and ``client_secret`` belong to the same application."""
        try:
            OAuthApplication.get(client_id=client_id, client_secret=client_secret)
        except OAuthApplication.DoesNotExist:
            return False
        return True

    def validate_redirect_uri(self, client_id, redirect_uri):
        """Return True when ``redirect_uri`` may be used with the application ``client_id``.

        The URI must start with the application's registered redirect URI. The match must also
        end on a URL component boundary. That means one of the following holds: the URIs are
        equal, the registered URI itself ends with a delimiter, or the remaining text starts
        with ``/``, ``?`` or ``#``. A bare prefix match would accept, for example,
        ``https://example.com.evil.net`` or ``https://example.com@evil.net`` for a registered
        ``https://example.com``, which would make the authorize endpoint an open redirector.

        A missing ``redirect_uri``, an unknown client, and an application without a registered
        redirect URI are all rejected.
        """
        if not redirect_uri:
            return False

        try:
            app = OAuthApplication.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return False

        registered = app.redirect_uri
        if not registered or not redirect_uri.startswith(registered):
            return False

        remainder = redirect_uri[len(registered):]
        return not remainder or registered[-1] in '/?#&' or remainder[0] in '/?#'

    def validate_scope(self, client_id, scopes_string):
        """Validate ``scopes_string`` with ``scopes.validate_scope_string``.

        ``client_id`` is unused. Scope validity does not depend on the client.
        """
        return scopes.validate_scope_string(scopes_string)

    def validate_access(self):
        """Return True when :meth:`get_authorized_user` returns a user (anything but ``None``)."""
        return self.get_authorized_user() is not None

    def load_authorized_scope_string(self, client_id, username):
        """Return the comma-joined scopes of ``username``'s unexpired tokens for ``client_id``.

        Returns an empty string when no such token exists.
        """
        # expires_at is written with naive local datetime.now() in persist_token_information,
        # so expiry is compared against the same clock here.
        found = list(OAuthAccessToken
                     .select()
                     .join(OAuthApplication)
                     .switch(OAuthAccessToken)
                     .join(User)
                     .where(OAuthApplication.client_id == client_id, User.username == username,
                            OAuthAccessToken.expires_at > datetime.now()))
        logger.debug('Found %s matching tokens.', len(found))

        long_scope_string = ','.join(token.scope for token in found)
        logger.debug('Computed long scope string: %s', long_scope_string)
        return long_scope_string

    def validate_has_scopes(self, client_id, username, scope):
        """Return whether ``scope`` is covered by the scopes ``username`` granted ``client_id``.

        Coverage is decided by ``scopes.is_subset_string`` over the combined scope string
        of all unexpired tokens.
        """
        long_scope_string = self.load_authorized_scope_string(client_id, username)

        # Make sure the token contains the given scopes (at least).
        return scopes.is_subset_string(long_scope_string, scope)

    def from_authorization_code(self, client_id, code, scope):
        """Return the ``data`` stored with the matching authorization code, or ``None``.

        A code matches when its application, code value and scope all equal the arguments.
        """
        try:
            found = (OAuthAuthorizationCode
                     .select()
                     .join(OAuthApplication)
                     .where(OAuthApplication.client_id == client_id,
                            OAuthAuthorizationCode.code == code,
                            OAuthAuthorizationCode.scope == scope)
                     .get())
        except OAuthAuthorizationCode.DoesNotExist:
            return None
        return found.data

    def from_refresh_token(self, client_id, refresh_token, scope):
        """Return the ``data`` stored with the matching access token, or ``None``.

        A token matches when its application, refresh token and scope all equal the arguments.
        """
        try:
            found = (OAuthAccessToken
                     .select()
                     .join(OAuthApplication)
                     .where(OAuthApplication.client_id == client_id,
                            OAuthAccessToken.refresh_token == refresh_token,
                            OAuthAccessToken.scope == scope)
                     .get())
        except OAuthAccessToken.DoesNotExist:
            return None
        return found.data

    def persist_authorization_code(self, client_id, code, scope):
        """Store ``code`` with ``scope`` for the application ``client_id``.

        Raises:
            OAuthApplication.DoesNotExist: if no application has ``client_id``.
        """
        app = OAuthApplication.get(client_id=client_id)
        OAuthAuthorizationCode.create(application=app, code=code, scope=scope)

    def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
                                  refresh_token, data):
        """Store an access token issued to the current authorized user.

        The token belongs to ``client_id``'s application. ``expires_at`` is set to
        ``expires_in`` seconds from now, using naive local time.

        Raises:
            OAuthApplication.DoesNotExist: if no application has ``client_id``.
        """
        app = OAuthApplication.get(client_id=client_id)
        user = self.get_authorized_user()
        expires_at = datetime.now() + timedelta(seconds=expires_in)

        OAuthAccessToken.create(application=app, authorized_user=user, scope=scope,
                                access_token=access_token, token_type=token_type,
                                expires_at=expires_at, refresh_token=refresh_token, data=data)

    def discard_authorization_code(self, client_id, code):
        """Delete the authorization code ``code`` issued to ``client_id``.

        Raises:
            OAuthAuthorizationCode.DoesNotExist: if no such code exists.
        """
        # The model is OAuthAuthorizationCode; the previous bare name ``AuthorizationCode``
        # was undefined and raised NameError.
        found = (OAuthAuthorizationCode
                 .select()
                 .join(OAuthApplication)
                 .where(OAuthApplication.client_id == client_id,
                        OAuthAuthorizationCode.code == code)
                 .get())
        found.delete_instance()

    def discard_refresh_token(self, client_id, refresh_token):
        """Delete the access token of ``client_id`` that carries ``refresh_token``.

        Raises:
            OAuthAccessToken.DoesNotExist: if no such token exists.
        """
        # The model is OAuthAccessToken; the previous bare name ``AccessToken`` was undefined
        # and raised NameError.
        found = (OAuthAccessToken
                 .select()
                 .join(OAuthApplication)
                 .where(OAuthApplication.client_id == client_id,
                        OAuthAccessToken.refresh_token == refresh_token)
                 .get())
        found.delete_instance()

    def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
        """Build the response sent when the user declines to authorize ``client_id``.

        The redirect URI is validated before anything else. RFC 6749 section 4.2.2.1 forbids
        redirecting the user-agent to an invalid redirection URI, even to report an error.
        """
        if not self.validate_redirect_uri(client_id, redirect_uri):
            return self._invalid_redirect_uri_response()

        # Ensure proper response_type
        if response_type != 'token':
            return self._make_redirect_error_response(redirect_uri, 'unsupported_response_type')

        return self._make_redirect_error_response(redirect_uri, 'authorization_denied')

    def get_token_response(self, response_type, client_id, redirect_uri, **params):
        """Handle an implicit-grant request: issue an access token and redirect back.

        The checks run in this order: redirect URI, ``response_type``, client, user access,
        and scope. Each failure yields an error response. On success a new token is persisted
        (without a refresh token). The response is a 302 whose ``Location`` is ``redirect_uri``
        plus ``params``, with the token in the URI fragment.

        ``generate_access_token``, ``token_type`` and ``token_expires_in`` are presumably
        supplied by the oauth2lib base class. TODO confirm.
        """
        # Validate the redirect URI first so that no error is ever redirected to an
        # unregistered URI (RFC 6749 section 4.2.2.1). This also rejects unknown clients.
        if not self.validate_redirect_uri(client_id, redirect_uri):
            return self._invalid_redirect_uri_response()

        # Ensure proper response_type
        if response_type != 'token':
            return self._make_redirect_error_response(redirect_uri, 'unsupported_response_type')

        # Check conditions
        is_valid_client_id = self.validate_client_id(client_id)
        is_valid_access = self.validate_access()
        scope = params.get('scope', '')
        are_valid_scopes = self.validate_scope(client_id, scope)

        # Return proper error responses on invalid conditions
        if not is_valid_client_id:
            return self._make_redirect_error_response(redirect_uri, 'unauthorized_client')

        if not is_valid_access:
            return self._make_redirect_error_response(redirect_uri, 'access_denied')

        if not are_valid_scopes:
            return self._make_redirect_error_response(redirect_uri, 'invalid_scope')

        access_token = self.generate_access_token()
        token_type = self.token_type
        expires_in = self.token_expires_in
        refresh_token = None  # No refresh token for this kind of flow

        self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
                                       token_type=token_type, expires_in=expires_in,
                                       refresh_token=refresh_token, data='')

        # The implicit grant returns the token in the URI fragment, not in the query string.
        url = utils.build_url(redirect_uri, params)
        url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)

        return self._make_response(headers={'Location': url}, status_code=302)
|
|
|
|
def create_application(org, name, application_uri, redirect_uri, **kwargs):
    """Create, save and return a new ``OAuthApplication`` owned by ``org``.

    Any extra keyword arguments are forwarded unchanged to ``OAuthApplication.create``.
    """
    return OAuthApplication.create(
        organization=org,
        name=name,
        application_uri=application_uri,
        redirect_uri=redirect_uri,
        **kwargs
    )
|
|
|
|
def validate_access_token(access_token):
    """Return the ``OAuthAccessToken`` whose token string equals ``access_token``, or ``None``.

    The authorizing ``User`` row is selected in the same query. Expiry is NOT checked here,
    so callers must compare ``expires_at`` themselves.
    """
    token_query = (OAuthAccessToken
                   .select(OAuthAccessToken, User)
                   .join(User)
                   .where(OAuthAccessToken.access_token == access_token))
    try:
        return token_query.get()
    except OAuthAccessToken.DoesNotExist:
        return None
|
|
|
|
def get_application_for_client_id(client_id):
    """Return the ``OAuthApplication`` with ``client_id``, or ``None`` if no such app exists."""
    try:
        application = OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
        application = None
    return application
|
|
|
|
def reset_client_secret(application):
    """Replace ``application``'s client secret, persist it, and return the application.

    The new secret comes from ``random_string_generator(length=40)``.
    """
    make_secret = random_string_generator(length=40)
    application.client_secret = make_secret()
    application.save()
    return application
|
|
|
|
def lookup_application(org, client_id):
    """Return ``org``'s application with ``client_id``, or ``None`` when ``org`` has none."""
    try:
        match = OAuthApplication.get(organization=org, client_id=client_id)
    except OAuthApplication.DoesNotExist:
        match = None
    return match
|
|
|
|
|
|
def delete_application(org, client_id):
    """Delete ``org``'s application ``client_id`` and return it.

    Returns ``None`` (and deletes nothing) when no such application exists.
    """
    target = lookup_application(org, client_id)
    if target:
        # recursive/delete_nullable: peewee also deletes rows that reference the application,
        # including those with nullable foreign keys, instead of nulling them out.
        target.delete_instance(recursive=True, delete_nullable=True)
        return target
    return None
|
|
|
|
|
|
def lookup_access_token_for_user(user, token_uuid):
    """Return ``user``'s access token whose ``uuid`` is ``token_uuid``, or ``None``."""
    owned_by_user = OAuthAccessToken.authorized_user == user
    has_uuid = OAuthAccessToken.uuid == token_uuid
    try:
        return OAuthAccessToken.get(owned_by_user, has_uuid)
    except OAuthAccessToken.DoesNotExist:
        return None
|
|
|
|
|
|
def list_access_tokens_for_user(user):
    """Return a query over all access tokens authorized by ``user``.

    The query joins each token's ``OAuthApplication`` and ``User``. The query object itself
    is returned, not a materialized list.
    """
    return (OAuthAccessToken
            .select()
            .join(OAuthApplication)
            .switch(OAuthAccessToken)
            .join(User)
            .where(OAuthAccessToken.authorized_user == user))
|
|
|
|
|
|
def list_applications_for_org(org):
    """Return a query over every ``OAuthApplication`` owned by ``org``, joined with ``User``.

    The query object itself is returned, not a materialized list.
    """
    return (OAuthApplication
            .select()
            .join(User)
            .where(OAuthApplication.organization == org))
|
|
|
|
|
|
def create_access_token_for_testing(user, client_id, scope):
    """Insert a fixed access token ``'test'`` for ``user`` (intended for tests only).

    The token belongs to the application with ``client_id``, has ``scope``, an empty refresh
    token and empty data, and expires 10000 seconds from now (naive local time).
    NOTE(review): if ``client_id`` is unknown, the application passed to ``create`` is ``None``.
    """
    expiry = datetime.now() + timedelta(seconds=10000)
    owning_app = get_application_for_client_id(client_id)
    OAuthAccessToken.create(
        application=owning_app,
        authorized_user=user,
        scope=scope,
        token_type='token',
        access_token='test',
        expires_at=expiry,
        refresh_token='',
        data='',
    )
|