import logging
import json

from datetime import datetime, timedelta
from oauth2lib.provider import AuthorizationProvider
from oauth2lib import utils

from data.database import (OAuthApplication, OAuthAuthorizationCode, OAuthAccessToken, User,
                           random_string_generator)
from data.model.legacy import get_user
from auth import scopes


logger = logging.getLogger(__name__)


class DatabaseAuthorizationProvider(AuthorizationProvider):
  """OAuth2 AuthorizationProvider backed by the application database.

  Client applications, authorization codes and access tokens are read from and written to the
  OAuthApplication, OAuthAuthorizationCode and OAuthAccessToken models. Subclasses must override
  get_authorized_user() to supply the user on whose behalf access is granted.
  """

  def get_authorized_user(self):
    """Return the user granting authorization; must be overridden by subclasses.

    The returned object needs a `username` attribute (read by _generate_data_string).
    validate_access() treats a None return as "access not granted".
    """
    raise NotImplementedError('Subclasses must fill in the ability to get the authorized_user.')

  def _generate_data_string(self):
    """Return the JSON blob stored alongside codes/tokens: {"username": <authorized user>}."""
    return json.dumps({'username': self.get_authorized_user().username})

  @property
  def token_expires_in(self):
    """Lifetime, in seconds, of newly issued access tokens (ten years)."""
    return int(60*60*24*365.25*10)  # 10 Years

  def validate_client_id(self, client_id):
    """Return True if an OAuthApplication is registered under client_id."""
    return self.get_application_for_client_id(client_id) is not None

  def get_application_for_client_id(self, client_id):
    """Return the OAuthApplication registered under client_id, or None if there is none."""
    try:
      return OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
      return None

  def validate_client_secret(self, client_id, client_secret):
    """Return True if client_id and client_secret belong to the same registered application."""
    try:
      OAuthApplication.get(client_id=client_id, client_secret=client_secret)
      return True
    except OAuthApplication.DoesNotExist:
      return False

  def validate_redirect_uri(self, client_id, redirect_uri):
    """Return True if redirect_uri is acceptable for client_id's application.

    The URI is accepted when the application has a non-empty registered redirect_uri and the
    requested URI starts with it. Unknown clients and missing/empty URIs are rejected.
    """
    # Guard against a missing URI, which would otherwise raise AttributeError on startswith().
    if not redirect_uri:
      return False

    try:
      app = OAuthApplication.get(client_id=client_id)
    except OAuthApplication.DoesNotExist:
      return False

    # NOTE(review): this is a plain string-prefix match, so a registered 'https://example.com'
    # also accepts 'https://example.com.attacker.net/'. Consider requiring a URL boundary
    # ('/', '?', '#') after the registered prefix.
    return bool(app.redirect_uri) and redirect_uri.startswith(app.redirect_uri)

  def validate_scope(self, client_id, scopes_string):
    """Return whether scopes_string is a valid scope string (client_id is not consulted)."""
    return scopes.validate_scope_string(scopes_string)

  def validate_access(self):
    """Return True if there is an authorized user (get_authorized_user() is not None)."""
    return self.get_authorized_user() is not None

  def load_authorized_scope_string(self, client_id, username):
    """Return the scopes of every unexpired token `username` granted to client_id, comma-joined.

    Returns an empty string when no such tokens exist.
    """
    found = (OAuthAccessToken
      .select()
      .join(OAuthApplication)
      .switch(OAuthAccessToken)
      .join(User)
      .where(OAuthApplication.client_id == client_id, User.username == username,
             OAuthAccessToken.expires_at > datetime.now()))
    found = list(found)
    logger.debug('Found %s matching tokens.', len(found))
    long_scope_string = ','.join([token.scope for token in found])
    logger.debug('Computed long scope string: %s', long_scope_string)
    return long_scope_string

  def validate_has_scopes(self, client_id, username, scope):
    """Return whether `scope` is covered by the scopes `username` already granted to client_id."""
    long_scope_string = self.load_authorized_scope_string(client_id, username)

    # Make sure the token contains the given scopes (at least).
    return scopes.is_subset_string(long_scope_string, scope)

  def from_authorization_code(self, client_id, code, scope):
    """Return the stored data blob for a matching authorization code, or None if none matches."""
    try:
      found = (OAuthAuthorizationCode
        .select()
        .join(OAuthApplication)
        .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code,
               OAuthAuthorizationCode.scope == scope)
        .get())
      logger.debug('Returning data: %s', found.data)
      return found.data
    except OAuthAuthorizationCode.DoesNotExist:
      return None

  def from_refresh_token(self, client_id, refresh_token, scope):
    """Return the stored data blob for a matching refresh token, or None if none matches."""
    try:
      found = (OAuthAccessToken
        .select()
        .join(OAuthApplication)
        .where(OAuthApplication.client_id == client_id,
               OAuthAccessToken.refresh_token == refresh_token,
               OAuthAccessToken.scope == scope)
        .get())
      return found.data
    except OAuthAccessToken.DoesNotExist:
      return None

  def persist_authorization_code(self, client_id, code, scope):
    """Store an authorization code for client_id, tagged with the authorized user's data blob.

    Raises OAuthApplication.DoesNotExist if client_id is not registered.
    """
    app = OAuthApplication.get(client_id=client_id)
    data = self._generate_data_string()
    OAuthAuthorizationCode.create(application=app, code=code, scope=scope, data=data)

  def persist_token_information(self, client_id, scope, access_token, token_type, expires_in,
                                refresh_token, data):
    """Store a newly issued access token.

    `data` is a JSON object string (as produced by _generate_data_string) whose 'username' names
    the authorizing user. The token expires `expires_in` seconds from now.

    Raises:
      RuntimeError: if `data` has no username or the named user cannot be found.
      OAuthApplication.DoesNotExist: if client_id is not registered.
    """
    # Read the key with .get() so a missing username raises the RuntimeError below rather
    # than a bare KeyError.
    username = json.loads(data).get('username')
    user = get_user(username) if username else None
    if not user:
      raise RuntimeError('Username must be in the data field')

    app = OAuthApplication.get(client_id=client_id)
    expires_at = datetime.now() + timedelta(seconds=expires_in)
    OAuthAccessToken.create(application=app, authorized_user=user, scope=scope,
                            access_token=access_token, token_type=token_type,
                            expires_at=expires_at, refresh_token=refresh_token, data=data)

  def discard_authorization_code(self, client_id, code):
    """Delete the matching authorization code.

    Raises OAuthAuthorizationCode.DoesNotExist if no code matches.
    """
    found = (OAuthAuthorizationCode
      .select()
      .join(OAuthApplication)
      .where(OAuthApplication.client_id == client_id, OAuthAuthorizationCode.code == code)
      .get())
    found.delete_instance()

  def discard_refresh_token(self, client_id, refresh_token):
    """Delete the access token carrying the given refresh token.

    Raises OAuthAccessToken.DoesNotExist if no token matches.
    """
    # Previously queried the undefined name `AccessToken`, so every call raised NameError.
    found = (OAuthAccessToken
      .select()
      .join(OAuthApplication)
      .where(OAuthApplication.client_id == client_id,
             OAuthAccessToken.refresh_token == refresh_token)
      .get())
    found.delete_instance()

  def get_auth_denied_response(self, response_type, client_id, redirect_uri, **params):
    """Build the response sent when the user denies an implicit-grant authorization request.

    An unverified redirect_uri yields the invalid-redirect response. Otherwise the user-agent is
    redirected with 'unsupported_response_type' (for response types other than 'token') or
    'authorization_denied'.
    """
    # Check the redirect URI first: RFC 6749 forbids redirecting the user-agent to an
    # unverified URI, even to report an error (sections 3.1.2.4 and 4.2.2.1).
    is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
    if not is_valid_redirect_uri:
      return self._invalid_redirect_uri_response()

    # Ensure proper response_type
    if response_type != 'token':
      err = 'unsupported_response_type'
      return self._make_redirect_error_response(redirect_uri, err)

    return self._make_redirect_error_response(redirect_uri, 'authorization_denied')

  def get_token_response(self, response_type, client_id, redirect_uri, **params):
    """Handle an implicit-grant (response_type='token') authorization request.

    On success, a new access token is generated and persisted. The response is a 302 to
    redirect_uri, with `params` as the query string and the token in the URL fragment. Errors
    are reported via redirect once redirect_uri has been verified.
    """
    # Check the redirect URI first: RFC 6749 forbids redirecting the user-agent to an
    # unverified URI, even to report an error (sections 3.1.2.4 and 4.2.2.1).
    is_valid_redirect_uri = self.validate_redirect_uri(client_id, redirect_uri)
    if not is_valid_redirect_uri:
      return self._invalid_redirect_uri_response()

    # Ensure proper response_type
    if response_type != 'token':
      err = 'unsupported_response_type'
      return self._make_redirect_error_response(redirect_uri, err)

    # Check conditions
    is_valid_client_id = self.validate_client_id(client_id)
    is_valid_access = self.validate_access()
    scope = params.get('scope', '')
    are_valid_scopes = self.validate_scope(client_id, scope)

    # Return proper error responses on invalid conditions
    if not is_valid_client_id:
      err = 'unauthorized_client'
      return self._make_redirect_error_response(redirect_uri, err)

    if not is_valid_access:
      err = 'access_denied'
      return self._make_redirect_error_response(redirect_uri, err)

    if not are_valid_scopes:
      err = 'invalid_scope'
      return self._make_redirect_error_response(redirect_uri, err)

    access_token = self.generate_access_token()
    token_type = self.token_type
    expires_in = self.token_expires_in
    refresh_token = None  # No refresh token for this kind of flow

    data = self._generate_data_string()
    self.persist_token_information(client_id=client_id, scope=scope, access_token=access_token,
                                   token_type=token_type, expires_in=expires_in,
                                   refresh_token=refresh_token, data=data)

    url = utils.build_url(redirect_uri, params)
    url += '#access_token=%s&token_type=%s&expires_in=%s' % (access_token, token_type, expires_in)

    return self._make_response(headers={'Location': url}, status_code=302)


def create_application(org, name, application_uri, redirect_uri, **kwargs):
  """Create, save and return a new OAuthApplication owned by `org`.

  Any extra keyword arguments are forwarded unchanged to OAuthApplication.create.
  """
  return OAuthApplication.create(
    organization=org,
    name=name,
    application_uri=application_uri,
    redirect_uri=redirect_uri,
    **kwargs
  )


def validate_access_token(access_token):
  """Look up an access token by its token string.

  Returns the matching OAuthAccessToken, with its User selected in the same query, or None if
  no token matches. Token expiry is not checked here.
  """
  token_query = (OAuthAccessToken
                 .select(OAuthAccessToken, User)
                 .join(User)
                 .where(OAuthAccessToken.access_token == access_token))
  try:
    return token_query.get()
  except OAuthAccessToken.DoesNotExist:
    return None


def get_application_for_client_id(client_id):
  """Return the OAuthApplication registered under client_id, or None if there is none."""
  try:
    application = OAuthApplication.get(client_id=client_id)
  except OAuthApplication.DoesNotExist:
    application = None
  return application


def reset_client_secret(application):
  """Give `application` a new client secret, save it, and return the application.

  The secret comes from random_string_generator(length=40).
  """
  make_secret = random_string_generator(length=40)
  application.client_secret = make_secret()
  application.save()
  return application


def lookup_application(org, client_id):
  """Return org's OAuthApplication with the given client_id, or None if not found."""
  try:
    return OAuthApplication.get(organization=org, client_id=client_id)
  except OAuthApplication.DoesNotExist:
    pass
  return None


def delete_application(org, client_id):
  """Delete org's application with the given client_id and return it.

  The delete is issued with recursive=True and delete_nullable=True, so dependent rows are
  removed too. Returns None when no such application exists.
  """
  application = lookup_application(org, client_id)
  if application:
    application.delete_instance(recursive=True, delete_nullable=True)
    return application
  return None


def lookup_access_token_for_user(user, token_uuid):
  """Return the access token with uuid `token_uuid` authorized by `user`, or None."""
  conditions = (OAuthAccessToken.authorized_user == user,
                OAuthAccessToken.uuid == token_uuid)
  try:
    return OAuthAccessToken.get(*conditions)
  except OAuthAccessToken.DoesNotExist:
    return None


def list_access_tokens_for_user(user):
  """Return an (unevaluated) query over the access tokens authorized by `user`.

  The query joins each token to its OAuthApplication and its User.
  """
  return (OAuthAccessToken
          .select()
          .join(OAuthApplication)
          .switch(OAuthAccessToken)
          .join(User)
          .where(OAuthAccessToken.authorized_user == user))


def list_applications_for_org(org):
  """Return an (unevaluated) query over the OAuth applications owned by `org`, joined to User."""
  return (OAuthApplication
          .select()
          .join(User)
          .where(OAuthApplication.organization == org))


def create_access_token_for_testing(user, client_id, scope):
  """Create an access token whose token string is 'test', for use in tests.

  The token belongs to client_id's application and is authorized by `user`. It expires 10000
  seconds from now and has an empty refresh token and data blob.
  """
  lifetime = timedelta(seconds=10000)
  expires_at = datetime.now() + lifetime
  application = get_application_for_client_id(client_id)
  OAuthAccessToken.create(
    application=application,
    authorized_user=user,
    scope=scope,
    token_type='token',
    access_token='test',
    expires_at=expires_at,
    refresh_token='',
    data=''
  )