import logging
import json

from datetime import datetime, timedelta
from oauth2lib.provider import AuthorizationProvider
from oauth2lib import utils

from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
                           random_string_generator)
from data.model.legacy import get_user
from auth import scopes
from flask import render_template


logger = logging.getLogger(__name__)


class DatabaseAuthorizationProvider(AuthorizationProvider):
  """ OAuth2 authorization provider that persists applications, authorization codes and
      access tokens in the OAuthApplication / OAuthAuthorizationCode / OAuthAccessToken tables.

      Subclasses must implement get_authorized_user() to supply the user on whose behalf
      codes and tokens are issued.
  """

  def get_authorized_user(self):
    """ Return the user authorizing the current request. Must be overridden by subclasses. """
    raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')

  def _generate_data_string(self):
    """ Return the JSON payload stored alongside codes/tokens: {"username": <authorized user>}. """
    return json.dumps({'username': self.get_authorized_user().username})

  @property
  def token_expires_in(self):
    """Property method to get the token expiration time in seconds.
    """
    return int(60*60*24*365.25*10) # 10 Years

  def validate_client_id(self, client_id):
    """ Return True if an application with the given client_id exists. """
    return self.get_application_for_client_id(client_id) is not None

  def get_application_for_client_id(self, client_id):
    """ Return the OAuthApplication with the given client_id, or None if there is none. """
    try:
      return OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
      return None

  def validate_client_secret(self, client_id, client_secret):
    """ Return True if an application matches both the client_id and the client_secret. """
    try:
      OAuthApplication.get(client_id=client_id, client_secret=client_secret)
    except OAuthApplication.DoesNotExist:
      return False
    return True

  def validate_redirect_uri(self, client_id, redirect_uri):
    """ Return True if redirect_uri starts with the redirect URI registered for the client.

        Returns False when the client is unknown, or when either URI is empty.
    """
    try:
      app = OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
      return False

    # NOTE(review): this is a plain string-prefix match. A registered URI such as
    # 'https://example.com' also accepts 'https://example.com.evil.net/...'. Consider
    # requiring the match to end on a '/', '?' or '#' boundary; left unchanged here
    # because existing clients may rely on the current behavior.
    if app.redirect_uri and redirect_uri and redirect_uri.startswith(app.redirect_uri):
      return True
    return False

  def validate_scope(self, client_id, scopes_string):
    """ Return whether scopes_string is a valid scope string according to auth.scopes. """
    return scopes.validate_scope_string(scopes_string)

  def validate_access(self):
    """ Return True if there is an authorized user for the current request. """
    return self.get_authorized_user() is not None

  def load_authorized_scope_string(self, client_id, username):
    """ Return the comma-joined scopes of every unexpired token issued to `username`
        for the application identified by `client_id` ('' if there are none).
    """
    found = (OAuthAccessToken
             .select()
             .join(OAuthApplication)
             .switch(OAuthAccessToken)
             .join(User)
             .where(OAuthApplication.client_id == client_id, User.username == username,
                    OAuthAccessToken.expires_at > datetime.utcnow()))
    found = list(found)
    logger.debug('Found %s matching tokens.', len(found))
    long_scope_string = ','.join([token.scope for token in found])
    logger.debug('Computed long scope string: %s', long_scope_string)
    return long_scope_string

  def validate_has_scopes(self, client_id, username, scope):
    """ Return whether the user's unexpired tokens for this client cover `scope`. """
    long_scope_string = self.load_authorized_scope_string(client_id, username)

    # Make sure the token contains the given scopes (at least).
    return scopes.is_subset_string(long_scope_string, scope)

  def from_authorization_code(self, client_id, code, scope):
    """ Return the data string stored with the matching authorization code, or None. """
    try:
      found = (OAuthAuthorizationCode
               .select()
               .join(OAuthApplication)
               .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
                      OAuthAuthorizationCode.scope == scope)
               .get())
    except OAuthAuthorizationCode.DoesNotExist:
      return None

    logger.debug('Returning data: %s', found.data)
    return found.data

  def from_refresh_token(self, client_id, refresh_token, scope):
    """ Return the data string stored with the token matching the refresh token, or None. """
    try:
      found = (OAuthAccessToken
               .select()
               .join(OAuthApplication)
               .where(OAuthApplication.client_id == client_id,
                      OAuthAccessToken.refresh_token == refresh_token,
                      OAuthAccessToken.scope == scope)
               .get())
    except OAuthAccessToken.DoesNotExist:
      return None

    return found.data

  def persist_authorization_code(self, client_id, code, scope):
    """ Store an authorization code for the client, tagged with the authorized user's data. """
    app = OAuthApplication.get(client_id=client_id)
    data = self._generate_data_string()
    OAuthAuthorizationCode.create(application=app, code=code, scope=scope, data=data)

  def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
                                refresh_token, data):
    """ Store a newly issued access token.

        `data` is the JSON string produced by _generate_data_string(); its 'username' key
        identifies the user the token is issued to.

        Raises:
          RuntimeError: if `data` has no username or the username does not resolve to a user.
    """
    # Use .get() so a payload without 'username' raises the documented RuntimeError
    # instead of a bare KeyError.
    username = json.loads(data).get('username')
    user = get_user(username) if username else None
    if not user:
      raise RuntimeError('Username must be in the data field')

    app = OAuthApplication.get(client_id=client_id)
    expires_at = datetime.utcnow() + timedelta(seconds=expires_in)
    OAuthAccessToken.create(application=app, authorized_user=user, scope=scope,
                            access_token=access_token, token_type=token_type,
                            expires_at=expires_at, refresh_token=refresh_token, data=data)

  def discard_authorization_code(self, client_id, code):
    """ Delete the matching authorization code.

        Raises OAuthAuthorizationCode.DoesNotExist (from .get()) if no code matches.
    """
    found = (OAuthAuthorizationCode
             .select()
             .join(OAuthApplication)
             .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
             .get())
    found.delete_instance()

  def discard_refresh_token(self, client_id, refresh_token):
    """ Delete the access token carrying the given refresh token.

        Raises OAuthAccessToken.DoesNotExist (from .get()) if no token matches.
    """
    # Previously referenced an undefined `AccessToken` name, so every call raised NameError.
    found = (OAuthAccessToken
             .select()
             .join(OAuthApplication)
             .where(OAuthApplication.client_id == client_id,
                    OAuthAccessToken.refresh_token == refresh_token)
             .get())
    found.delete_instance()

  def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
    """ Build the response for a user who denied authorization.

        Only the 'token' response type is supported. An unregistered redirect URI gets the
        invalid-redirect response rather than a redirect.
    """
    # Ensure proper response_type
    if response_type != 'token':
      err = 'unsupported_response_type'
      return self._make_redirect_error_response(redirect_uri, err)

    # Check redirect URI
    is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
    if not is_valid_redirect_uri:
      return self._invalid_redirect_uri_response()

    return self._make_redirect_error_response(redirect_uri, 'authorization_denied')

  def get_token_response(self, response_type, client_id, redirect_uri, **params):
    """ Issue an access token directly (the 'token' response type) and return the response.

        If `redirect_uri` is the literal 'display', the token is rendered via message.html
        instead of redirecting. Otherwise a 302 is returned whose URL fragment carries
        access_token, token_type and expires_in. No refresh token is issued.
    """
    # Ensure proper response_type
    if response_type != 'token':
      err = 'unsupported_response_type'
      return self._make_redirect_error_response(redirect_uri, err)

    # Check redirect URI ('display' is a special value that bypasses registration).
    is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
    if redirect_uri != 'display' and not is_valid_redirect_uri:
      return self._invalid_redirect_uri_response()

    # Check conditions
    is_valid_client_id = self.validate_client_id(client_id)
    is_valid_access = self.validate_access()
    scope = params.get('scope', '')
    are_valid_scopes = self.validate_scope(client_id, scope)

    # Return proper error responses on invalid conditions
    if not is_valid_client_id:
      err = 'unauthorized_client'
      return self._make_redirect_error_response(redirect_uri, err)

    if not is_valid_access:
      err = 'access_denied'
      return self._make_redirect_error_response(redirect_uri, err)

    if not are_valid_scopes:
      err = 'invalid_scope'
      return self._make_redirect_error_response(redirect_uri, err)

    access_token = self.generate_access_token()
    token_type = self.token_type
    expires_in = self.token_expires_in
    refresh_token = None # No refresh token for this kind of flow

    data = self._generate_data_string()
    self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
                                   token_type=token_type, expires_in=expires_in,
                                   refresh_token=refresh_token, data=data)

    if redirect_uri == 'display':
      return self._make_response(
        render_template("message.html", message="Access Token: " + access_token))

    # The redirect URL is only needed on the redirect path, so build it here.
    url = utils.build_url(redirect_uri, params)
    url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)

    return self._make_response(headers={'Location': url}, status_code=302)


def create_application(org, name, application_uri, redirect_uri, **kwargs):
  """ Create and return an OAuthApplication owned by `org`; extra kwargs go to create(). """
  return OAuthApplication.create(organization=org, name=name, application_uri=application_uri,
                                 redirect_uri=redirect_uri, **kwargs)


def validate_access_token(access_token):
  """ Return the OAuthAccessToken (with its User selected) for `access_token`, or None.

      NOTE(review): expiry is not checked here; confirm whether callers check expires_at.
  """
  try:
    found = (OAuthAccessToken
             .select(OAuthAccessToken, User)
             .join(User)
             .where(OAuthAccessToken.access_token == access_token)
             .get())
  except OAuthAccessToken.DoesNotExist:
    return None
  return found


def get_application_for_client_id(client_id):
  """ Return the OAuthApplication with the given client_id, or None if there is none. """
  try:
    return OAuthApplication.get(client_id=client_id)
  except OAuthApplication.DoesNotExist:
    return None


def reset_client_secret(application):
  """ Replace the application's client secret with a new random 40-character value and save. """
  application.client_secret = random_string_generator(length=40)()
  application.save()
  return application


def lookup_application(org, client_id):
  """ Return the application with `client_id` owned by `org`, or None. """
  try:
    return OAuthApplication.get(organization=org, client_id=client_id)
  except OAuthApplication.DoesNotExist:
    return None


def delete_application(org, client_id):
  """ Delete the org's application (and dependent rows) and return it; None if not found. """
  application = lookup_application(org, client_id)
  if not application:
    return

  # recursive/delete_nullable also remove rows referencing the application (e.g. its tokens).
  application.delete_instance(recursive=True, delete_nullable=True)
  return application


def lookup_access_token_for_user(user, token_uuid):
  """ Return the access token with `token_uuid` authorized by `user`, or None. """
  try:
    return OAuthAccessToken.get(OAuthAccessToken.authorized_user == user,
                                OAuthAccessToken.uuid == token_uuid)
  except OAuthAccessToken.DoesNotExist:
    return None


def list_access_tokens_for_user(user):
  """ Return a query over all access tokens authorized by `user`, joined to app and user. """
  query = (OAuthAccessToken
           .select()
           .join(OAuthApplication)
           .switch(OAuthAccessToken)
           .join(User)
           .where(OAuthAccessToken.authorized_user == user))

  return query


def list_applications_for_org(org):
  """ Return a query over all applications owned by `org`. """
  query = (OAuthApplication
           .select()
           .join(User)
           .where(OAuthApplication.organization == org))

  return query


def create_access_token_for_testing(user, client_id, scope):
  """ Create a fixed 'test' access token for `user` on `client_id`, expiring in 10000 seconds. """
  expires_at = datetime.utcnow() + timedelta(seconds=10000)
  application = get_application_for_client_id(client_id)
  OAuthAccessToken.create(application=application, authorized_user=user, scope=scope,
                          token_type='token', access_token='test',
                          expires_at=expires_at, refresh_token='', data='')