diff --git a/data/database.py b/data/database.py index b3d2eafff..2defc7b56 100644 --- a/data/database.py +++ b/data/database.py @@ -287,7 +287,7 @@ class OAuthAuthorizationCode(BaseModel): application = ForeignKeyField(OAuthApplication) code = CharField(index=True) scope = CharField() - data = CharField(default=random_string_generator()) + data = CharField() # Context for the code, such as the user class OAuthAccessToken(BaseModel): @@ -298,7 +298,7 @@ class OAuthAccessToken(BaseModel): token_type = CharField(default='Bearer') expires_at = DateTimeField() refresh_token = CharField(index=True, null=True) - data = CharField() # What the hell is this field for? + data = TextField() # This is context for which this token was generated, such as the user class NotificationKind(BaseModel): diff --git a/data/model/oauth.py b/data/model/oauth.py index 2e83019e4..037d0c53c 100644 --- a/data/model/oauth.py +++ b/data/model/oauth.py @@ -1,4 +1,5 @@ import logging +import json from datetime import datetime, timedelta from oauth2lib.provider import AuthorizationProvider @@ -6,6 +7,7 @@ from oauth2lib import utils from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User, random_string_generator) +from data.model.legacy import get_user from auth import scopes @@ -16,6 +18,9 @@ class DatabaseAuthorizationProvider(AuthorizationProvider): def get_authorized_user(self): raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.') + def _generate_data_string(self): + return json.dumps({'username': self.get_authorized_user().username}) + def validate_client_id(self, client_id): return self.get_application_for_client_id(client_id) is not None @@ -75,6 +80,7 @@ class DatabaseAuthorizationProvider(AuthorizationProvider): .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code, OAuthAuthorizationCode.scope == scope) .get()) + logger.debug('Returning data: %s', found.data) return found.data except OAuthAuthorizationCode.DoesNotExist: return None @@ -94,19 +100,23 @@ class DatabaseAuthorizationProvider(AuthorizationProvider): def persist_authorization_code(self, client_id, code, scope): app = OAuthApplication.get(client_id=client_id) - OAuthAuthorizationCode.create(application=app, code=code, scope=scope) + data = self._generate_data_string() + OAuthAuthorizationCode.create(application=app, code=code, scope=scope, data=data) def persist_token_information(self, client_id, scope, access_token, token_type, expires_in, refresh_token, data): + user = get_user(json.loads(data)['username']) + if not user: + raise RuntimeError('Username must be in the data field') + app = OAuthApplication.get(client_id=client_id) - user = self.get_authorized_user() expires_at = datetime.now() + timedelta(seconds=expires_in) OAuthAccessToken.create(application=app, authorized_user=user, scope=scope, access_token=access_token, token_type=token_type, expires_at=expires_at, refresh_token=refresh_token, data=data) def discard_authorization_code(self, client_id, code): - found = (AuthorizationCode + found = (OAuthAuthorizationCode .select() .join(OAuthApplication) .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code) @@ -172,9 +182,10 @@ class DatabaseAuthorizationProvider(AuthorizationProvider): expires_in = self.token_expires_in refresh_token = None # No refresh token for this kind of flow + data = self._generate_data_string() self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token, token_type=token_type, expires_in=expires_in, - refresh_token=refresh_token, data='') + refresh_token=refresh_token, data=data) url = utils.build_url(redirect_uri, params) url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in) @@ -182,7 +193,8 @@ class DatabaseAuthorizationProvider(AuthorizationProvider): return self._make_response(headers={'Location': url}, status_code=302) def create_application(org, name, application_uri, redirect_uri, **kwargs): - return OAuthApplication.create(organization=org, name=name, application_uri=application_uri, redirect_uri=redirect_uri, **kwargs) + return OAuthApplication.create(organization=org, name=name, application_uri=application_uri, + redirect_uri=redirect_uri, **kwargs) def validate_access_token(access_token): try: diff --git a/endpoints/web.py b/endpoints/web.py index 6b9a3e2e0..a99e132b8 100644 --- a/endpoints/web.py +++ b/endpoints/web.py @@ -331,3 +331,17 @@ def request_authorization_code(): return provider.get_token_response(response_type, client_id, redirect_uri, scope=scope) else: return provider.get_authorization_code(response_type, client_id, redirect_uri, scope=scope) + + +@web.route('/oauth/access_token', methods=['POST']) +@no_cache +def exchange_code_for_token(): + grant_type = request.form.get('grant_type', None) + client_id = request.form.get('client_id', None) + client_secret = request.form.get('client_secret', None) + redirect_uri = request.form.get('redirect_uri', None) + code = request.form.get('code', None) + scope = request.form.get('scope', None) + + provider = FlaskAuthorizationProvider() + return provider.get_token(grant_type, client_id, client_secret, redirect_uri, code, scope=scope) diff --git a/test/data/test.db b/test/data/test.db index 4e24e334a..e55773dc0 100644 Binary files a/test/data/test.db and b/test/data/test.db differ