651666b60b
Breaks out the validation code from the auth context modification calls, makes decorators easier to define and adds testing for each individual piece. Will be the basis of better error messaging in the following change.
293 lines
No EOL
10 KiB
Python
293 lines
No EOL
10 KiB
Python
import logging
|
|
import json
|
|
|
|
from flask import url_for
|
|
from datetime import datetime, timedelta
|
|
from oauth2lib.provider import AuthorizationProvider
|
|
from oauth2lib import utils
|
|
|
|
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
|
|
AccessToken, random_string_generator)
|
|
from data.model import user, config
|
|
from auth import scopes
|
|
from util import get_app_url
|
|
|
|
|
|
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class DatabaseAuthorizationProvider(AuthorizationProvider):
    """OAuth2 ``AuthorizationProvider`` whose state lives in the database models.

    Applications are read from ``OAuthApplication``; authorization codes and access
    tokens are stored in ``OAuthAuthorizationCode`` and ``OAuthAccessToken``.
    Subclasses must override :meth:`get_authorized_user`.
    """

    def get_authorized_user(self):
        """Return the user on whose behalf authorization is being granted.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')

    def _generate_data_string(self):
        """Return the JSON ``data`` string stored alongside codes and tokens.

        It records the authorized user's username, which
        :meth:`persist_token_information` reads back.
        """
        return json.dumps({'username': self.get_authorized_user().username})

    @property
    def token_expires_in(self):
        """Property method to get the token expiration time in seconds.
        """
        return int(60*60*24*365.25*10)  # 10 Years

    def validate_client_id(self, client_id):
        """Return True if an ``OAuthApplication`` is registered under ``client_id``."""
        return self.get_application_for_client_id(client_id) is not None

    def get_application_for_client_id(self, client_id):
        """Return the ``OAuthApplication`` for ``client_id``, or None if there is none."""
        try:
            return OAuthApplication.get(client_id=client_id)
        except OAuthApplication.DoesNotExist:
            return None

    def validate_client_secret(self, client_id, client_secret):
        """Return True if an application has both this ``client_id`` and ``client_secret``."""
        try:
            OAuthApplication.get(client_id=client_id, client_secret=client_secret)
            return True
        except OAuthApplication.DoesNotExist:
            return False

    def validate_redirect_uri(self, client_id, redirect_uri):
        """Return True if responses for ``client_id`` may be redirected to ``redirect_uri``.

        This app's own local OAuth handler URL is always accepted. Otherwise
        ``redirect_uri`` must start with the redirect URI registered for the client's
        application AND have the same scheme and host (netloc). The netloc check
        prevents a registration such as ``https://example.com`` from also accepting
        ``https://example.com.evil.net`` or ``https://example.com@evil.net``.
        """
        try:
            from urllib.parse import urlparse  # Python 3
        except ImportError:
            from urlparse import urlparse  # Python 2

        internal_redirect_url = '%s%s' % (get_app_url(config.app_config),
                                          url_for('web.oauth_local_handler'))
        if redirect_uri == internal_redirect_url:
            return True

        oauth_app = self.get_application_for_client_id(client_id)
        if oauth_app is None or not oauth_app.redirect_uri or not redirect_uri:
            return False

        if not redirect_uri.startswith(oauth_app.redirect_uri):
            return False

        # A prefix match alone is unsafe when the registered URI ends inside the host
        # part, so also require an identical scheme and network location.
        registered = urlparse(oauth_app.redirect_uri)
        requested = urlparse(redirect_uri)
        return (requested.scheme == registered.scheme and
                requested.netloc == registered.netloc)

    def validate_scope(self, client_id, scopes_string):
        """Return ``scopes.validate_scope_string(scopes_string)``; ``client_id`` is unused."""
        return scopes.validate_scope_string(scopes_string)

    def validate_access(self):
        """Return True if :meth:`get_authorized_user` yields a user (i.e. not None)."""
        return self.get_authorized_user() is not None

    def load_authorized_scope_string(self, client_id, username):
        """Return the comma-joined scopes of the user's unexpired tokens for this client.

        Every ``OAuthAccessToken`` issued by the application ``client_id`` to the user
        ``username`` whose ``expires_at`` is still in the future (UTC) contributes its
        ``scope``. The result is an empty string when no token matches.
        """
        found = list(OAuthAccessToken
                     .select()
                     .join(OAuthApplication)
                     .switch(OAuthAccessToken)
                     .join(User)
                     .where(OAuthApplication.client_id == client_id,
                            User.username == username,
                            OAuthAccessToken.expires_at > datetime.utcnow()))
        logger.debug('Found %s matching tokens.', len(found))
        long_scope_string = ','.join([token.scope for token in found])
        logger.debug('Computed long scope string: %s', long_scope_string)
        return long_scope_string

    def validate_has_scopes(self, client_id, username, scope):
        """Return whether ``scope`` is covered by the user's unexpired tokens for this client.

        The combined scopes come from :meth:`load_authorized_scope_string` and are
        compared with ``scopes.is_subset_string``.
        """
        long_scope_string = self.load_authorized_scope_string(client_id, username)

        # Make sure the token contains the given scopes (at least).
        return scopes.is_subset_string(long_scope_string, scope)

    def from_authorization_code(self, client_id, code, scope):
        """Return the ``data`` stored with a matching authorization code, or None.

        The code must belong to the application for ``client_id`` and carry exactly
        ``scope``.
        """
        try:
            found = (OAuthAuthorizationCode
                     .select()
                     .join(OAuthApplication)
                     .where(OAuthApplication.client_id == client_id,
                            OAuthAuthorizationCode.code == code,
                            OAuthAuthorizationCode.scope == scope)
                     .get())
            logger.debug('Returning data: %s', found.data)
            return found.data
        except OAuthAuthorizationCode.DoesNotExist:
            return None

    def from_refresh_token(self, client_id, refresh_token, scope):
        """Return the ``data`` of the access token with this refresh token and scope, or None."""
        try:
            found = (OAuthAccessToken
                     .select()
                     .join(OAuthApplication)
                     .where(OAuthApplication.client_id == client_id,
                            OAuthAccessToken.refresh_token == refresh_token,
                            OAuthAccessToken.scope == scope)
                     .get())
            return found.data
        except OAuthAccessToken.DoesNotExist:
            return None

    def persist_authorization_code(self, client_id, code, scope):
        """Store ``code`` (granting ``scope``) for the application identified by ``client_id``.

        The authorized user's data string is saved with it. A missing application
        propagates ``OAuthApplication.DoesNotExist`` from the lookup.
        """
        oauth_app = OAuthApplication.get(client_id=client_id)
        data = self._generate_data_string()
        OAuthAuthorizationCode.create(application=oauth_app, code=code, scope=scope, data=data)

    def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
                                  refresh_token, data):
        """Create an ``OAuthAccessToken`` for the user named in ``data``.

        Args:
            data: JSON string (as built by ``_generate_data_string``) whose
                ``username`` key identifies the authorized user.
            expires_in: token lifetime in seconds, counted from now (UTC).

        Raises:
            RuntimeError: if ``data`` has no ``username`` or ``user.get_user``
                returns nothing for it.
        """
        # .get() so that a missing key reports the RuntimeError below instead of a
        # bare KeyError.
        username = json.loads(data).get('username')
        found = user.get_user(username) if username else None
        if not found:
            raise RuntimeError('Username must be in the data field')

        oauth_app = OAuthApplication.get(client_id=client_id)
        expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
        OAuthAccessToken.create(application=oauth_app, authorized_user=found, scope=scope,
                                access_token=access_token, token_type=token_type,
                                expires_at=expires_at, refresh_token=refresh_token, data=data)

    def discard_authorization_code(self, client_id, code):
        """Delete authorization code ``code`` issued to ``client_id``.

        ``OAuthAuthorizationCode.DoesNotExist`` propagates if no such code exists.
        """
        found = (OAuthAuthorizationCode
                 .select()
                 .join(OAuthApplication)
                 .where(OAuthApplication.client_id == client_id,
                        OAuthAuthorizationCode.code == code)
                 .get())
        found.delete_instance()

    def discard_refresh_token(self, client_id, refresh_token):
        """Delete the OAuth access token carrying ``refresh_token`` for ``client_id``.

        ``OAuthAccessToken.DoesNotExist`` propagates if no such token exists.
        """
        # Must query OAuthAccessToken: the join and the filter below are both on
        # OAuthAccessToken's relations/columns.
        found = (OAuthAccessToken
                 .select()
                 .join(OAuthApplication)
                 .where(OAuthApplication.client_id == client_id,
                        OAuthAccessToken.refresh_token == refresh_token)
                 .get())
        found.delete_instance()

    def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
        """Build the response for a user who denied the authorization request.

        ``redirect_uri`` is validated before anything is redirected to it. An
        unverified URI gets ``_invalid_redirect_uri_response()`` instead of an error
        redirect, which avoids an open redirect (RFC 6749 section 4.2.2.1).
        """
        # Never redirect to a URI that has not been verified for this client.
        if not self.validate_redirect_uri(client_id, redirect_uri):
            return self._invalid_redirect_uri_response()

        # Ensure proper response_type
        if response_type != 'token':
            err = 'unsupported_response_type'
            return self._make_redirect_error_response(redirect_uri, err)

        return self._make_redirect_error_response(redirect_uri, 'authorization_denied')

    def get_token_response(self, response_type, client_id, redirect_uri, **params):
        """Handle an implicit-grant request and redirect back with a new access token.

        On success, a token is generated and persisted via
        :meth:`persist_token_information`, with no refresh token. The method then
        returns a 302 whose ``Location`` is ``redirect_uri`` with ``params`` added
        and ``access_token``, ``token_type`` and ``expires_in`` in the URL fragment.

        Errors are reported by redirecting to ``redirect_uri``, but only after it has
        been validated for ``client_id``. An unverified URI yields
        ``_invalid_redirect_uri_response()`` instead.
        """
        # Validate the redirect URI first: every error below is delivered by
        # redirecting to it, so it must be trusted before any redirect is issued.
        if not self.validate_redirect_uri(client_id, redirect_uri):
            return self._invalid_redirect_uri_response()

        # Ensure proper response_type
        if response_type != 'token':
            err = 'unsupported_response_type'
            return self._make_redirect_error_response(redirect_uri, err)

        # Check for a valid client ID.
        if not self.validate_client_id(client_id):
            err = 'unauthorized_client'
            return self._make_redirect_error_response(redirect_uri, err)

        # Check conditions
        is_valid_access = self.validate_access()
        scope = params.get('scope', '')
        are_valid_scopes = self.validate_scope(client_id, scope)

        # Return proper error responses on invalid conditions
        if not is_valid_access:
            err = 'access_denied'
            return self._make_redirect_error_response(redirect_uri, err)

        if not are_valid_scopes:
            err = 'invalid_scope'
            return self._make_redirect_error_response(redirect_uri, err)

        access_token = self.generate_access_token()
        token_type = self.token_type
        expires_in = self.token_expires_in
        refresh_token = None  # No refresh token for this kind of flow

        data = self._generate_data_string()
        self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
                                       token_type=token_type, expires_in=expires_in,
                                       refresh_token=refresh_token, data=data)

        url = utils.build_url(redirect_uri, params)
        url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)

        return self._make_response(headers={'Location': url}, status_code=302)
|
|
|
|
|
|
def create_application(org, name, application_uri, redirect_uri, **kwargs):
    """Create and return a new ``OAuthApplication`` owned by ``org``.

    Any extra keyword arguments are passed straight through to
    ``OAuthApplication.create``.
    """
    application = OAuthApplication.create(
        organization=org,
        name=name,
        application_uri=application_uri,
        redirect_uri=redirect_uri,
        **kwargs
    )
    return application
|
|
|
|
|
|
def validate_access_token(access_token):
    """Look up an OAuth access token by its token string.

    Returns the matching ``OAuthAccessToken``, with its ``User`` selected in the same
    query, or None when no token has that value. Expiration is not checked here.
    """
    query = (OAuthAccessToken
             .select(OAuthAccessToken, User)
             .join(User)
             .where(OAuthAccessToken.access_token == access_token))
    try:
        return query.get()
    except OAuthAccessToken.DoesNotExist:
        return None
|
|
|
|
|
|
def get_application_for_client_id(client_id):
    """Return the ``OAuthApplication`` registered under ``client_id``, or None."""
    try:
        application = OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
        application = None
    return application
|
|
|
|
|
|
def reset_client_secret(application):
    """Give ``application`` a freshly generated client secret, save it and return it.

    The secret comes from ``random_string_generator(length=40)``.
    """
    make_secret = random_string_generator(length=40)
    application.client_secret = make_secret()
    application.save()
    return application
|
|
|
|
|
|
def lookup_application(org, client_id):
    """Return ``org``'s application whose client id is ``client_id``, or None."""
    try:
        found = OAuthApplication.get(organization=org, client_id=client_id)
    except OAuthApplication.DoesNotExist:
        found = None
    return found
|
|
|
|
|
|
def delete_application(org, client_id):
    """Delete ``org``'s application ``client_id`` and return it.

    Deletion passes ``recursive=True, delete_nullable=True`` to
    ``delete_instance``. Returns None when no such application exists.
    """
    application = lookup_application(org, client_id)
    if application:
        application.delete_instance(recursive=True, delete_nullable=True)
        return application
    return None
|
|
|
|
|
|
def lookup_access_token_for_user(user_obj, token_uuid):
    """Return the access token authorized by ``user_obj`` with uuid ``token_uuid``, or None."""
    owned_by_user = OAuthAccessToken.authorized_user == user_obj
    has_uuid = OAuthAccessToken.uuid == token_uuid
    try:
        return OAuthAccessToken.get(owned_by_user, has_uuid)
    except OAuthAccessToken.DoesNotExist:
        return None
|
|
|
|
|
|
def list_access_tokens_for_user(user_obj):
    """Return an (unevaluated) query over every OAuth access token authorized by ``user_obj``.

    The query joins each token to its ``OAuthApplication`` and its ``User``.
    """
    return (OAuthAccessToken
            .select()
            .join(OAuthApplication)
            .switch(OAuthAccessToken)
            .join(User)
            .where(OAuthAccessToken.authorized_user == user_obj))
|
|
|
|
|
|
def list_applications_for_org(org):
    """Return an (unevaluated) query over the OAuth applications owned by ``org``."""
    return (OAuthApplication
            .select()
            .join(User)
            .where(OAuthApplication.organization == org))
|
|
|
|
|
|
def create_access_token_for_testing(user_obj, client_id, scope, access_token='test',
                                    expires_in=10000):
    """Directly create an ``OAuthAccessToken`` for ``user_obj``, skipping the OAuth flow.

    Args:
        client_id: client id of the issuing application. It is resolved with
            ``get_application_for_client_id``, which yields None if unknown.
        expires_in: lifetime in seconds from now (UTC).

    The token is created with ``token_type='token'`` and empty ``refresh_token``
    and ``data``.
    """
    expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
    application = get_application_for_client_id(client_id)
    token = OAuthAccessToken.create(application=application,
                                    authorized_user=user_obj,
                                    scope=scope,
                                    token_type='token',
                                    access_token=access_token,
                                    expires_at=expires_at,
                                    refresh_token='',
                                    data='')
    return token