import urlparse
import json
import logging
import time

from cachetools import TTLCache
from jwkest.jwk import KEYS
from util import slash_join

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

class OAuthConfig(object):
  """ Base class describing how to talk to a single OAuth service.

      The settings for the service are read from the `key_name` section of the supplied
      config. Subclasses provide the service name and its endpoints; this class supplies the
      client credentials, the redirect URI and the authorization code exchange.
  """
  def __init__(self, config, key_name):
    self.key_name = key_name

    # A missing (or otherwise falsy) section is treated as an empty configuration.
    section = config.get(key_name)
    self.config = section if section else {}

  def service_name(self):
    """ Returns the display name of the service. Must be implemented by subclasses. """
    raise NotImplementedError

  def token_endpoint(self):
    """ Returns the URL used to exchange a code for a token. Must be implemented by
        subclasses.
    """
    raise NotImplementedError

  def user_endpoint(self):
    """ Returns the URL of the user information endpoint. Must be implemented by
        subclasses.
    """
    raise NotImplementedError

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Checks the configured client ID and secret against the service. Must be implemented
        by subclasses.
    """
    raise NotImplementedError

  def client_id(self):
    """ Returns the configured CLIENT_ID, or None if it is not set. """
    return self.config.get('CLIENT_ID')

  def client_secret(self):
    """ Returns the configured CLIENT_SECRET, or None if it is not set. """
    return self.config.get('CLIENT_SECRET')

  def get_redirect_uri(self, app_config, redirect_suffix=''):
    """ Builds the OAuth callback URL for this service from the application's preferred URL
        scheme and server hostname, with `redirect_suffix` appended verbatim.
    """
    scheme = app_config['PREFERRED_URL_SCHEME']
    hostname = app_config['SERVER_HOSTNAME']
    service = self.service_name().lower()
    return '%s://%s/oauth2/%s/callback%s' % (scheme, hostname, service, redirect_suffix)


  def exchange_code_for_token(self, app_config, http_client, code, form_encode=False,
                              redirect_suffix='', client_auth=False):
    """ Exchanges the given authorization code for an access token by POSTing to the token
        endpoint.

        If `form_encode` is set, the payload is sent as the request body; otherwise it is sent
        as query parameters. If `client_auth` is set, the client credentials are passed as the
        request's `auth` pair; otherwise they are added to the payload.

        Returns the access token, or an empty string if the response JSON is empty or has no
        `access_token` entry.
    """
    request_payload = {
      'code': code,
      'grant_type': 'authorization_code',
      'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix)
    }

    if not client_auth:
      # The credentials travel inside the payload itself.
      credentials = None
      request_payload['client_id'] = self.client_id()
      request_payload['client_secret'] = self.client_secret()
    else:
      credentials = (self.client_id(), self.client_secret())

    endpoint = self.token_endpoint()

    # The payload goes in the body when form encoding, and in the query string otherwise.
    payload_kwarg = 'data' if form_encode else 'params'
    response = http_client.post(endpoint, headers={'Accept': 'application/json'},
                                auth=credentials, **{payload_kwarg: request_payload})

    token_data = response.json()
    return token_data.get('access_token', '') if token_data else ''


class GithubOAuthConfig(OAuthConfig):
  """ OAuth configuration for GitHub, covering both github.com and GitHub Enterprise
      installations (selected via the GITHUB_ENDPOINT and API_ENDPOINT settings).
  """
  def __init__(self, config, key_name):
    super(GithubOAuthConfig, self).__init__(config, key_name)

  def service_name(self):
    return 'GitHub'

  def allowed_organizations(self):
    """ Returns the lowercased names of the organizations to which login is restricted, or
        None if ORG_RESTRICT is disabled or no ALLOWED_ORGANIZATIONS list is configured.
    """
    if not self.config.get('ORG_RESTRICT', False):
      return None

    allowed = self.config.get('ALLOWED_ORGANIZATIONS', None)
    if allowed is None:
      return None

    return [org.lower() for org in allowed]

  def get_public_url(self, suffix):
    """ Returns the given suffix joined onto the GitHub (web) endpoint. """
    return slash_join(self._endpoint(), suffix)

  def _endpoint(self):
    """ Returns the configured GITHUB_ENDPOINT, defaulting to https://github.com. """
    return self.config.get('GITHUB_ENDPOINT', 'https://github.com')

  def is_enterprise(self):
    """ Returns True if the configured endpoint is not hosted on github.com (or one of its
        subdomains), i.e. it points at a GitHub Enterprise installation.
    """
    # Compare against the hostname rather than searching the whole URL for '.github.com':
    # a substring search misses the bare 'github.com' host (no leading dot) and would report
    # the default public endpoint as an Enterprise installation.
    netloc = self._endpoint().split('://', 1)[-1].split('/', 1)[0]

    # Drop any userinfo and port, leaving just the host.
    hostname = netloc.rsplit('@', 1)[-1].split(':', 1)[0].lower()

    is_public_github = hostname == 'github.com' or hostname.endswith('.github.com')
    return not is_public_github

  def authorize_endpoint(self):
    """ Returns the authorization URL, ending in '?' so that query parameters can be appended
        directly.
    """
    return slash_join(self._endpoint(), '/login/oauth/authorize')  + '?'

  def token_endpoint(self):
    return slash_join(self._endpoint(), '/login/oauth/access_token')

  def _api_endpoint(self):
    """ Returns the configured API_ENDPOINT, defaulting to /api/v3/ under the GitHub
        endpoint.
    """
    return self.config.get('API_ENDPOINT', slash_join(self._endpoint(), '/api/v3/'))

  def api_endpoint(self):
    """ Returns the API endpoint with a single trailing slash (if present) removed. """
    endpoint = self._api_endpoint()
    if endpoint.endswith('/'):
      return endpoint[:-1]

    return endpoint

  def user_endpoint(self):
    return slash_join(self._api_endpoint(), 'user')

  def email_endpoint(self):
    return slash_join(self._api_endpoint(), 'user/emails')

  def orgs_endpoint(self):
    return slash_join(self._api_endpoint(), 'user/orgs')

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns whether the configured client ID and secret are accepted by the configured
        GitHub API endpoint. Raises an Exception if the endpoint does not appear to be a
        GitHub (Enterprise) installation.
    """
    # First: Verify that the github endpoint is actually Github by checking for the
    # X-GitHub-Request-Id here.
    api_endpoint = self._api_endpoint()
    result = http_client.get(api_endpoint, auth=(self.client_id(), self.client_secret()), timeout=5)
    if 'X-GitHub-Request-Id' not in result.headers:
      raise Exception('Endpoint is not a Github (Enterprise) installation')

    # Next: Verify the client ID and secret.
    # Note: The following code is a hack until such time as Github officially adds an API endpoint
    # for verifying a {client_id, client_secret} pair. That being said, this hack was given to us
    # *by a Github Engineer*, so I think it is okay for the time being :)
    #
    # TODO(jschorr): Replace with the real API call once added.
    #
    # Hitting the endpoint applications/{client_id}/tokens/foo will result in the following
    # behavior IF the client_id is given as the HTTP username and the client_secret as the HTTP
    # password:
    #   - If the {client_id, client_secret} pair is invalid in some way, we get a 401 error.
    #   - If the pair is valid, then we get a 404 because the 'foo' token does not exist.
    validate_endpoint = slash_join(api_endpoint, 'applications/%s/tokens/foo' % self.client_id())
    result = http_client.get(validate_endpoint, auth=(self.client_id(), self.client_secret()),
                             timeout=5)
    return result.status_code == 404

  def validate_organization(self, organization_id, http_client):
    """ Returns whether a lookup of the (lowercased) organization via the API endpoint
        responds with a 200.
    """
    org_endpoint = slash_join(self._api_endpoint(), 'orgs/%s' % organization_id.lower())

    result = http_client.get(org_endpoint,
                             headers={'Accept': 'application/vnd.github.moondragon+json'},
                             timeout=5)

    return result.status_code == 200


  def get_public_config(self):
    """ Returns the subset of the configuration that is exposed publicly; the client secret
        is not included.
    """
    return  {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
      'GITHUB_ENDPOINT': self._endpoint(),
      'ORG_RESTRICT': self.config.get('ORG_RESTRICT', False)
    }



class GoogleOAuthConfig(OAuthConfig):
  """ OAuth configuration for Google login. All endpoints are fixed Google URLs. """
  def __init__(self, config, key_name):
    super(GoogleOAuthConfig, self).__init__(config, key_name)

  def service_name(self):
    return 'Google'

  def authorize_endpoint(self):
    """ Returns the authorization URL, ending in '&' so that further query parameters can be
        appended directly.
    """
    return 'https://accounts.google.com/o/oauth2/auth?response_type=code&'

  def token_endpoint(self):
    return 'https://accounts.google.com/o/oauth2/token'

  def user_endpoint(self):
    return 'https://www.googleapis.com/oauth2/v1/userinfo'

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns whether Google accepts the configured client ID and secret.

        The check deliberately sends an invalid code exchange request to
        https://www.googleapis.com/oauth2/v3/token. A 401 (Unauthorized) response is taken to
        mean the client ID or secret is invalid; any other status code is treated as the
        credentials having been accepted.
    """
    fake_exchange = {
      'code': 'fakecode',
      'client_id': self.client_id(),
      'client_secret': self.client_secret(),
      'grant_type': 'authorization_code',
      'redirect_uri': 'http://example.com'
    }

    response = http_client.post('https://www.googleapis.com/oauth2/v3/token',
                                data=fake_exchange, timeout=5)
    rejected = response.status_code == 401
    return not rejected

  def get_public_config(self):
    """ Returns the subset of the configuration that is exposed publicly; the client secret
        is not included.
    """
    public_config = {}
    public_config['CLIENT_ID'] = self.client_id()
    public_config['AUTHORIZE_ENDPOINT'] = self.authorize_endpoint()
    return public_config


class GitLabOAuthConfig(OAuthConfig):
  """ OAuth configuration for GitLab, either gitlab.com or the installation named by the
      GITLAB_ENDPOINT setting.
  """
  def __init__(self, config, key_name):
    super(GitLabOAuthConfig, self).__init__(config, key_name)

  def _endpoint(self):
    """ Returns the configured GITLAB_ENDPOINT, defaulting to https://gitlab.com. """
    return self.config.get('GITLAB_ENDPOINT', 'https://gitlab.com')

  def api_endpoint(self):
    """ Returns the API endpoint, which is the GitLab endpoint itself. """
    return self._endpoint()

  def get_public_url(self, suffix):
    """ Returns the given suffix joined onto the GitLab endpoint. """
    base_url = self._endpoint()
    return slash_join(base_url, suffix)

  def service_name(self):
    return 'GitLab'

  def authorize_endpoint(self):
    base_url = self._endpoint()
    return slash_join(base_url, '/oauth/authorize')

  def token_endpoint(self):
    base_url = self._endpoint()
    return slash_join(base_url, '/oauth/token')

  def validate_client_id_and_secret(self, http_client, app_config):
    """ Returns whether GitLab accepts the configured client ID and secret.

        A code exchange with a fake code is sent to the token endpoint and the error reported
        in the JSON response is inspected: an empty response or an `invalid_client` error
        means the credentials were rejected; anything else is treated as acceptance.
    """
    token_url = self.token_endpoint()
    fake_exchange = {
      'code': 'fakecode',
      'redirect_uri': self.get_redirect_uri(app_config, redirect_suffix='trigger'),
      'client_id': self.client_id(),
      'client_secret': self.client_secret(),
      'grant_type': 'authorization_code',
    }

    response = http_client.post(token_url, data=fake_exchange, timeout=5)
    body = response.json()
    if body:
      return body.get('error', '') != 'invalid_client'

    return False

  def get_public_config(self):
    """ Returns the subset of the configuration that is exposed publicly; the client secret
        is not included.
    """
    public_config = {}
    public_config['CLIENT_ID'] = self.client_id()
    public_config['AUTHORIZE_ENDPOINT'] = self.authorize_endpoint()
    public_config['GITLAB_ENDPOINT'] = self._endpoint()
    return public_config


# Relative path of the OIDC discovery document; it is resolved against the configured
# OIDC_SERVER URL when loading the provider's configuration.
OIDC_WELLKNOWN = ".well-known/openid-configuration"
# Time-to-live, in seconds, of the cached OIDC provider public key.
PUBLIC_KEY_CACHE_TTL = 3600 # 1 hour

class OIDCConfig(OAuthConfig):
  """ OAuth configuration for an OpenID Connect provider.

      If OIDC_SERVER is configured, the provider's endpoints are loaded at construction time
      from its discovery document. The provider's public key is fetched lazily from the
      discovered `jwks_uri` and cached for PUBLIC_KEY_CACHE_TTL seconds.
  """
  def __init__(self, config, key_name):
    super(OIDCConfig, self).__init__(config, key_name)

    # Single-entry cache; on a miss it calls _get_public_key to load the key.
    self._public_key_cache = TTLCache(1, PUBLIC_KEY_CACHE_TTL, missing=self._get_public_key)
    self._oidc_config = {}
    self._http_client = config['HTTPCLIENT']

    if self.config.get('OIDC_SERVER'):
      self._load_via_discovery(config.get('DEBUGGING', False))

  def _load_via_discovery(self, is_debugging):
    """ Loads the provider configuration from its discovery document. Non-HTTPS servers are
        only allowed when `is_debugging` is set. Raises an Exception if the server is
        insecure, the document cannot be retrieved, or it cannot be parsed.
    """
    oidc_server = self.config['OIDC_SERVER']
    if not oidc_server.startswith('https://') and not is_debugging:
      raise Exception('OIDC server must be accessed over SSL')

    # urljoin replaces the final path segment of a base URL that lacks a trailing slash, so
    # 'https://host/dex' would otherwise resolve to 'https://host/.well-known/...'. Ensure
    # the well-known path is appended beneath the full server URL instead.
    base_url = oidc_server if oidc_server.endswith('/') else oidc_server + '/'
    discovery_url = urlparse.urljoin(base_url, OIDC_WELLKNOWN)
    discovery = self._http_client.get(discovery_url, timeout=5)

    # Floor division keeps the status class comparison correct regardless of whether true
    # division is in effect.
    if discovery.status_code // 100 != 2:
      raise Exception("Could not load OIDC discovery information")

    try:
      self._oidc_config = json.loads(discovery.text)
    except ValueError:
      logger.exception('Could not parse OIDC discovery for url: %s', discovery_url)
      raise Exception("Could not parse OIDC discovery information")

  def authorize_endpoint(self):
    """ Returns the discovered authorization endpoint, ending in '?' so that query parameters
        can be appended directly.
    """
    return self._oidc_config.get('authorization_endpoint', '') + '?'

  def token_endpoint(self):
    """ Returns the discovered token endpoint, or None if none was discovered. """
    return self._oidc_config.get('token_endpoint')

  def user_endpoint(self):
    return None

  def validate_client_id_and_secret(self, http_client, app_config):
    # No validation is performed for OIDC providers.
    pass

  def get_public_config(self):
    """ Returns the subset of the configuration that is exposed publicly; the client secret
        is not included.
    """
    return {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint()
    }

  @property
  def issuer(self):
    """ The configured OIDC_ISSUER if present, otherwise the OIDC_SERVER. Raises a KeyError
        if neither is configured.
    """
    # Only fall back to OIDC_SERVER when no issuer is configured; looking it up eagerly would
    # raise a KeyError for configurations that set OIDC_ISSUER without OIDC_SERVER.
    if 'OIDC_ISSUER' in self.config:
      return self.config['OIDC_ISSUER']

    return self.config['OIDC_SERVER']

  def get_public_key(self, force_refresh=False):
    """ Retrieves the public key for this handler, loading it from the provider if it is not
        cached. If `force_refresh` is set, any cached key is discarded first.
    """
    # If force_refresh is true, we expire all the items in the cache by setting the time to
    # the current time + the expiration TTL.
    if force_refresh:
      self._public_key_cache.expire(time=time.time() + PUBLIC_KEY_CACHE_TTL)

    # Retrieve the public key from the cache. If the cache does not contain the public key,
    # it will internally call _get_public_key to retrieve it and then save it. The None is
    # an arbitrary key chosen to be stored in the cache, and could be anything.
    return self._public_key_cache[None]

  def _get_public_key(self, _cache_key=None):
    """ Loads the provider's keys from the discovered `jwks_uri` and returns the first one in
        PEM format. Raises an Exception if the provider returns no keys.

        `_cache_key` is ignored; it exists because the cache invokes its `missing` callback
        with the key being looked up.
    """
    keys_url = self._oidc_config['jwks_uri']

    keys = KEYS()
    keys.load_from_url(keys_url)

    loaded_keys = list(keys)
    if not loaded_keys:
      raise Exception('No keys provided by OIDC provider')

    rsa_key = loaded_keys[0]
    rsa_key.deserialize()
    return rsa_key.key.exportKey('PEM')


class DexOAuthConfig(OIDCConfig):
  """ OIDC configuration for Dex, adding a fixed title and logo to the public config. """
  def service_name(self):
    return 'Dex'

  @property
  def public_title(self):
    """ The title under which this service is presented, taken from the public config. """
    public_config = self.get_public_config()
    return public_config['OIDC_TITLE']

  def get_public_config(self):
    """ Returns the subset of the configuration that is exposed publicly, along with the
        title and logo to display for the service.
    """
    public_config = {
      'CLIENT_ID': self.client_id(),
      'AUTHORIZE_ENDPOINT': self.authorize_endpoint(),
    }

    # TODO(jschorr): This should ideally come from the Dex side.
    public_config['OIDC_TITLE'] = 'Dex'
    public_config['OIDC_LOGO'] = 'https://tectonic.com/assets/ico/favicon-96x96.png'
    return public_config