Finally figure out what the data field is supposed to be for and use it to implement and fix 3LO.
This commit is contained in:
parent
e92cf37583
commit
cbc40588cb
4 changed files with 33 additions and 7 deletions
|
@ -287,7 +287,7 @@ class OAuthAuthorizationCode(BaseModel):
|
||||||
application = ForeignKeyField(OAuthApplication)
|
application = ForeignKeyField(OAuthApplication)
|
||||||
code = CharField(index=True)
|
code = CharField(index=True)
|
||||||
scope = CharField()
|
scope = CharField()
|
||||||
data = CharField(default=random_string_generator())
|
data = CharField() # Context for the code, such as the user
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccessToken(BaseModel):
|
class OAuthAccessToken(BaseModel):
|
||||||
|
@ -298,7 +298,7 @@ class OAuthAccessToken(BaseModel):
|
||||||
token_type = CharField(default='Bearer')
|
token_type = CharField(default='Bearer')
|
||||||
expires_at = DateTimeField()
|
expires_at = DateTimeField()
|
||||||
refresh_token = CharField(index=True, null=True)
|
refresh_token = CharField(index=True, null=True)
|
||||||
data = CharField() # What the hell is this field for?
|
data = TextField() # This is context for which this token was generated, such as the user
|
||||||
|
|
||||||
|
|
||||||
class NotificationKind(BaseModel):
|
class NotificationKind(BaseModel):
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import json
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from oauth2lib.provider import AuthorizationProvider
|
from oauth2lib.provider import AuthorizationProvider
|
||||||
|
@ -6,6 +7,7 @@ from oauth2lib import utils
|
||||||
|
|
||||||
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
|
from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
|
||||||
random_string_generator)
|
random_string_generator)
|
||||||
|
from data.model.legacy import get_user
|
||||||
from auth import scopes
|
from auth import scopes
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,6 +18,9 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
def get_authorized_user(self):
|
def get_authorized_user(self):
|
||||||
raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')
|
raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')
|
||||||
|
|
||||||
|
def _generate_data_string(self):
|
||||||
|
return json.dumps({'username': self.get_authorized_user().username})
|
||||||
|
|
||||||
def validate_client_id(self, client_id):
|
def validate_client_id(self, client_id):
|
||||||
return self.get_application_for_client_id(client_id) is not None
|
return self.get_application_for_client_id(client_id) is not None
|
||||||
|
|
||||||
|
@ -75,6 +80,7 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
|
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
|
||||||
OAuthAuthorizationCode.scope == scope)
|
OAuthAuthorizationCode.scope == scope)
|
||||||
.get())
|
.get())
|
||||||
|
logger.debug('Returning data: %s', found.data)
|
||||||
return found.data
|
return found.data
|
||||||
except OAuthAuthorizationCode.DoesNotExist:
|
except OAuthAuthorizationCode.DoesNotExist:
|
||||||
return None
|
return None
|
||||||
|
@ -94,19 +100,23 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
|
|
||||||
def persist_authorization_code(self, client_id, code, scope):
|
def persist_authorization_code(self, client_id, code, scope):
|
||||||
app = OAuthApplication.get(client_id=client_id)
|
app = OAuthApplication.get(client_id=client_id)
|
||||||
OAuthAuthorizationCode.create(application=app, code=code, scope=scope)
|
data = self._generate_data_string()
|
||||||
|
OAuthAuthorizationCode.create(application=app, code=code, scope=scope, data=data)
|
||||||
|
|
||||||
def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
|
def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
|
||||||
refresh_token, data):
|
refresh_token, data):
|
||||||
|
user = get_user(json.loads(data)['username'])
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError('Username must be in the data field')
|
||||||
|
|
||||||
app = OAuthApplication.get(client_id=client_id)
|
app = OAuthApplication.get(client_id=client_id)
|
||||||
user = self.get_authorized_user()
|
|
||||||
expires_at = datetime.now() + timedelta(seconds=expires_in)
|
expires_at = datetime.now() + timedelta(seconds=expires_in)
|
||||||
OAuthAccessToken.create(application=app, authorized_user=user, scope=scope,
|
OAuthAccessToken.create(application=app, authorized_user=user, scope=scope,
|
||||||
access_token=access_token, token_type=token_type,
|
access_token=access_token, token_type=token_type,
|
||||||
expires_at=expires_at, refresh_token=refresh_token, data=data)
|
expires_at=expires_at, refresh_token=refresh_token, data=data)
|
||||||
|
|
||||||
def discard_authorization_code(self, client_id, code):
|
def discard_authorization_code(self, client_id, code):
|
||||||
found = (AuthorizationCode
|
found = (OAuthAuthorizationCode
|
||||||
.select()
|
.select()
|
||||||
.join(OAuthApplication)
|
.join(OAuthApplication)
|
||||||
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
|
.where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
|
||||||
|
@ -172,9 +182,10 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
expires_in = self.token_expires_in
|
expires_in = self.token_expires_in
|
||||||
refresh_token = None # No refresh token for this kind of flow
|
refresh_token = None # No refresh token for this kind of flow
|
||||||
|
|
||||||
|
data = self._generate_data_string()
|
||||||
self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
|
self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
|
||||||
token_type=token_type, expires_in=expires_in,
|
token_type=token_type, expires_in=expires_in,
|
||||||
refresh_token=refresh_token, data='')
|
refresh_token=refresh_token, data=data)
|
||||||
|
|
||||||
url = utils.build_url(redirect_uri, params)
|
url = utils.build_url(redirect_uri, params)
|
||||||
url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)
|
url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)
|
||||||
|
@ -182,7 +193,8 @@ class DatabaseAuthorizationProvider(AuthorizationProvider):
|
||||||
return self._make_response(headers={'Location': url}, status_code=302)
|
return self._make_response(headers={'Location': url}, status_code=302)
|
||||||
|
|
||||||
def create_application(org, name, application_uri, redirect_uri, **kwargs):
|
def create_application(org, name, application_uri, redirect_uri, **kwargs):
|
||||||
return OAuthApplication.create(organization=org, name=name, application_uri=application_uri, redirect_uri=redirect_uri, **kwargs)
|
return OAuthApplication.create(organization=org, name=name, application_uri=application_uri,
|
||||||
|
redirect_uri=redirect_uri, **kwargs)
|
||||||
|
|
||||||
def validate_access_token(access_token):
|
def validate_access_token(access_token):
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -331,3 +331,17 @@ def request_authorization_code():
|
||||||
return provider.get_token_response(response_type, client_id, redirect_uri, scope=scope)
|
return provider.get_token_response(response_type, client_id, redirect_uri, scope=scope)
|
||||||
else:
|
else:
|
||||||
return provider.get_authorization_code(response_type, client_id, redirect_uri, scope=scope)
|
return provider.get_authorization_code(response_type, client_id, redirect_uri, scope=scope)
|
||||||
|
|
||||||
|
|
||||||
|
@web.route('/oauth/access_token', methods=['POST'])
|
||||||
|
@no_cache
|
||||||
|
def exchange_code_for_token():
|
||||||
|
grant_type = request.form.get('grant_type', None)
|
||||||
|
client_id = request.form.get('client_id', None)
|
||||||
|
client_secret = request.form.get('client_secret', None)
|
||||||
|
redirect_uri = request.form.get('redirect_uri', None)
|
||||||
|
code = request.form.get('code', None)
|
||||||
|
scope = request.form.get('scope', None)
|
||||||
|
|
||||||
|
provider = FlaskAuthorizationProvider()
|
||||||
|
return provider.get_token(grant_type, client_id, client_secret, redirect_uri, code, scope=scope)
|
||||||
|
|
Binary file not shown.
Reference in a new issue