Accidental refactor, split out legacy.py into separate sumodules and update all call sites.
This commit is contained in:
parent
2109d24483
commit
3efaa255e8
92 changed files with 4458 additions and 4269 deletions
|
@ -7,13 +7,14 @@ from oauth2lib.provider import AuthorizationProvider
|
|||
from oauth2lib import utils
|
||||
|
||||
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
|
||||
random_string_generator)
|
||||
from data.model.legacy import get_user
|
||||
AccessToken, random_string_generator)
|
||||
from data.model import user
|
||||
from auth import scopes
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||
def get_authorized_user(self):
|
||||
raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')
|
||||
|
@ -49,7 +50,8 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
|
||||
try:
|
||||
oauth_app = OAuthApplication.get(client_id=client_id)
|
||||
if oauth_app.redirect_uri and redirect_uri and redirect_uri.startswith(oauth_app.redirect_uri):
|
||||
if (oauth_app.redirect_uri and redirect_uri and
|
||||
redirect_uri.startswith(oauth_app.redirect_uri)):
|
||||
return True
|
||||
return False
|
||||
except OAuthApplication.DoesNotExist:
|
||||
|
@ -63,12 +65,12 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
|
||||
def load_authorized_scope_string(self, client_id, username):
|
||||
found = (OAuthAccessToken
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.switch(OAuthAccessToken)
|
||||
.join(User)
|
||||
.where(OAuthApplication.client_id == client_id, User.username == username,
|
||||
OAuthAccessToken.expires_at > datetime.utcnow()))
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.switch(OAuthAccessToken)
|
||||
.join(User)
|
||||
.where(OAuthApplication.client_id == client_id, User.username == username,
|
||||
OAuthAccessToken.expires_at > datetime.utcnow()))
|
||||
found = list(found)
|
||||
logger.debug('Found %s matching tokens.', len(found))
|
||||
long_scope_string = ','.join([token.scope for token in found])
|
||||
|
@ -84,11 +86,11 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
def from_authorization_code(self, client_id, code, scope):
|
||||
try:
|
||||
found = (OAuthAuthorizationCode
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.scope == scope)
|
||||
.get())
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
|
||||
OAuthAuthorizationCode.scope == scope)
|
||||
.get())
|
||||
logger.debug('Returning data: %s', found.data)
|
||||
return found.data
|
||||
except OAuthAuthorizationCode.DoesNotExist:
|
||||
|
@ -97,12 +99,12 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
def from_refresh_token(self, client_id, refresh_token, scope):
|
||||
try:
|
||||
found = (OAuthAccessToken
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id,
|
||||
OAuthAccessToken.refresh_token == refresh_token,
|
||||
OAuthAccessToken.scope == scope)
|
||||
.get())
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id,
|
||||
OAuthAccessToken.refresh_token == refresh_token,
|
||||
OAuthAccessToken.scope == scope)
|
||||
.get())
|
||||
return found.data
|
||||
except OAuthAccessToken.DoesNotExist:
|
||||
return None
|
||||
|
@ -114,31 +116,31 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
|
||||
def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
|
||||
refresh_token, data):
|
||||
user = get_user(json.loads(data)['username'])
|
||||
if not user:
|
||||
found = user.get_user(json.loads(data)['username'])
|
||||
if not found:
|
||||
raise RuntimeError('Username must be in the data field')
|
||||
|
||||
oauth_app = OAuthApplication.get(client_id=client_id)
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
|
||||
OAuthAccessToken.create(application=oauth_app, authorized_user=user, scope=scope,
|
||||
OAuthAccessToken.create(application=oauth_app, authorized_user=found, scope=scope,
|
||||
access_token=access_token, token_type=token_type,
|
||||
expires_at=expires_at, refresh_token=refresh_token, data=data)
|
||||
|
||||
def discard_authorization_code(self, client_id, code):
|
||||
found = (OAuthAuthorizationCode
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
|
||||
.get())
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
|
||||
.get())
|
||||
found.delete_instance()
|
||||
|
||||
def discard_refresh_token(self, client_id, refresh_token):
|
||||
found = (AccessToken
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id,
|
||||
OAuthAccessToken.refresh_token == refresh_token)
|
||||
.get())
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.where(OAuthApplication.client_id == client_id,
|
||||
OAuthAccessToken.refresh_token == refresh_token)
|
||||
.get())
|
||||
found.delete_instance()
|
||||
|
||||
|
||||
|
@ -157,7 +159,6 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
|||
|
||||
|
||||
def get_token_response(self, response_type, client_id, redirect_uri, **params):
|
||||
|
||||
# Ensure proper response_type
|
||||
if response_type != 'token':
|
||||
err = 'unsupported_response_type'
|
||||
|
@ -211,10 +212,10 @@ def create_application(org, name, application_uri, redirect_uri, **kwargs):
|
|||
def validate_access_token(access_token):
|
||||
try:
|
||||
found = (OAuthAccessToken
|
||||
.select(OAuthAccessToken, User)
|
||||
.join(User)
|
||||
.where(OAuthAccessToken.access_token == access_token)
|
||||
.get())
|
||||
.select(OAuthAccessToken, User)
|
||||
.join(User)
|
||||
.where(OAuthAccessToken.access_token == access_token)
|
||||
.get())
|
||||
return found
|
||||
except OAuthAccessToken.DoesNotExist:
|
||||
return None
|
||||
|
@ -235,7 +236,7 @@ def reset_client_secret(application):
|
|||
|
||||
def lookup_application(org, client_id):
|
||||
try:
|
||||
return OAuthApplication.get(organization = org, client_id=client_id)
|
||||
return OAuthApplication.get(organization=org, client_id=client_id)
|
||||
except OAuthApplication.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
@ -249,21 +250,21 @@ def delete_application(org, client_id):
|
|||
return application
|
||||
|
||||
|
||||
def lookup_access_token_for_user(user, token_uuid):
|
||||
def lookup_access_token_for_user(user_obj, token_uuid):
|
||||
try:
|
||||
return OAuthAccessToken.get(OAuthAccessToken.authorized_user == user,
|
||||
return OAuthAccessToken.get(OAuthAccessToken.authorized_user == user_obj,
|
||||
OAuthAccessToken.uuid == token_uuid)
|
||||
except OAuthAccessToken.DoesNotExist:
|
||||
return None
|
||||
|
||||
|
||||
def list_access_tokens_for_user(user):
|
||||
def list_access_tokens_for_user(user_obj):
|
||||
query = (OAuthAccessToken
|
||||
.select()
|
||||
.join(OAuthApplication)
|
||||
.switch(OAuthAccessToken)
|
||||
.join(User)
|
||||
.where(OAuthAccessToken.authorized_user == user))
|
||||
.where(OAuthAccessToken.authorized_user == user_obj))
|
||||
|
||||
return query
|
||||
|
||||
|
@ -277,9 +278,9 @@ def list_applications_for_org(org):
|
|||
return query
|
||||
|
||||
|
||||
def create_access_token_for_testing(user, client_id, scope):
|
||||
def create_access_token_for_testing(user_obj, client_id, scope):
|
||||
expires_at = datetime.utcnow() + timedelta(seconds=10000)
|
||||
application = get_application_for_client_id(client_id)
|
||||
OAuthAccessToken.create(application=application, authorized_user=user, scope=scope,
|
||||
OAuthAccessToken.create(application=application, authorized_user=user_obj, scope=scope,
|
||||
token_type='token', access_token='test',
|
||||
expires_at=expires_at, refresh_token='', data='')
|
||||
|
|
Reference in a new issue