import json
import logging
import multiprocessing
import time

from ctypes import c_bool
from datetime import datetime, timedelta
from threading import Thread
from functools import total_ordering
from enum import Enum, IntEnum
from collections import namedtuple

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.serialization import load_pem_public_key
from dateutil import parser
from flask import make_response

import jwt


logger = logging.getLogger(__name__)


# Grace periods added to a subscription's end date (Expiration.expires_at) before it is
# treated as expired; License._sub_expiration picks one based on the subscription kind.
# NOTE(review): the monthly grace period (335 days) is longer than the yearly one (90 days).
# The inline comments suggest this is deliberate, but confirm against the licensing policy.
TRIAL_GRACE_PERIOD = timedelta(days=7)    # 1 week
MONTHLY_GRACE_PERIOD = timedelta(days=335) # 11 months
YEARLY_GRACE_PERIOD = timedelta(days=90)  # 3 months

# LicenseValidator additionally validates at (now + this delta) to report "expiring soon".
LICENSE_SOON_DELTA = timedelta(days=7) # 1 week

# Presumably the name of the license file read by the config provider; it is not
# referenced in the code visible here — verify against callers.
LICENSE_FILENAME = 'license'

# Entitlement names looked up in each subscription's `entitlements` mapping.
QUAY_ENTITLEMENT = 'software.quay'
QUAY_DEPLOYMENTS_ENTITLEMENT = 'software.quay.deployments'



class LicenseDecodeError(Exception):
  """ Raised when a license cannot be read or decoded, or when it has expired. """


def _get_date(decoded, field, default_date=datetime.min):
  """ Retrieves the encoded date found at the given field under the decoded license block. """
  date_str = decoded.get(field)
  return parser.parse(date_str).replace(tzinfo=None) if date_str else default_date


@total_ordering
class Entitlement(object):
  """ A single grant (a piece of software or functionality) conferred by a license.

      Each entitlement records how many units are granted (`count`), the name of
      the product it came from, and an Expiration describing when it lapses.
      Entitlements compare by `count` alone.
  """
  def __init__(self, entitlement_name, count, product_name, expiration):
    self.name = entitlement_name
    self.count = count
    self.product_name = product_name
    self.expiration = expiration

  def __lt__(self, rhs):
    # Only the granted count participates in ordering.
    return self.count < rhs.count

  def __repr__(self):
    fields = dict(
      name=self.name,
      count=self.count,
      product_name=self.product_name,
      expiration=repr(self.expiration),
    )
    return str(fields)

  def as_dict(self, for_private=False):
    """ Serializes the entitlement. Only the name is exposed unless `for_private`
        is set, in which case count, product name and the full expiration
        details are included too.
    """
    if not for_private:
      return {'name': self.name}

    return {
      'name': self.name,
      'count': self.count,
      'product_name': self.product_name,
      'expiration': self.expiration.as_dict(for_private=True),
    }

class ExpirationType(Enum):
  """ An enum of the kinds of expiration an entitlement can be subject to. An
      entitlement's effective Expiration is the earlier of its subscription's
      expiration and the license-wide one, so if you possess an expired
      Expiration, its type tells you which level imposed the binding cutoff.
  """
  license_wide = 'License Wide Expiration'        # the license's top-level `expirationDate`
  trial_only = 'Trial Only Expiration'            # subscriptions with `trialOnly` (uses `trialEnd`)
  in_trial = 'In-Trial Expiration'                # subscriptions with `inTrial` (uses `serviceEnd`)
  monthly = 'Monthly Subscription Expiration'     # default when no yearly period is specified
  yearly = 'Yearly Subscription Expiration'       # subscriptions with durationPeriod == 'yearly'


@total_ordering
class Expiration(object):
  """ Orderable pairing of an expiration date with a grace period.

      The effective cutoff, `expires_at`, is the expiration date pushed back by
      the grace period. Comparisons between Expirations use only that cutoff, so
      sorting (or min/max) ranks them by when they truly lapse.
  """
  def __init__(self, expiration_type, exp_date, grace_period=timedelta(seconds=0)):
    self.expiration_type = expiration_type
    self.expiration_date = exp_date
    self.grace_period = grace_period

  @property
  def expires_at(self):
    """ The real cutoff: expiration date extended by the grace period. """
    cutoff = self.expiration_date + self.grace_period
    return cutoff

  def is_expired(self, now):
    """ True when the cutoff lies strictly before `now`; a cutoff equal to
        `now` does not count as expired yet.
    """
    cutoff = self.expires_at
    return cutoff < now

  def __lt__(self, rhs):
    mine, theirs = self.expires_at, rhs.expires_at
    return mine < theirs

  def __repr__(self):
    state = dict(
      expiration_type=repr(self.expiration_type),
      expiration_date=repr(self.expiration_date),
      grace_period=repr(self.grace_period),
    )
    return str(state)

  def as_dict(self, for_private=False):
    """ Serializes the expiration. Only the type is exposed unless `for_private`
        is set, in which case the raw date and grace period are added as strings.
    """
    if not for_private:
      return {'expiration_type': str(self.expiration_type)}

    return {
      'expiration_type': str(self.expiration_type),
      'expiration_date': str(self.expiration_date),
      'grace_period': str(self.grace_period),
    }


class EntitlementStatus(IntEnum):
  """ An EntitlementStatus represents the current effectiveness of an
      Entitlement when compared with its corresponding requirement. As an
      example, if the software requires 9 items, and the Entitlement only
      provides for 7, you would use an insufficient_count status.

      Lower values are more satisfiable; EntitlementValidationResult orders by
      these values so that the best candidate sorts first.
  """
  met = 0                 # not expired (or still within grace) and count is sufficient
  expired = 1             # the entitlement's cutoff (date + grace) has passed
  insufficient_count = 2  # not expired, but grants fewer units than required
  no_matching = 3         # no entitlement with the required name was found


@total_ordering
class EntitlementValidationResult(object):
  """ An EntitlementValidationResult encodes the combination of a specific
      entitlement and the software requirement which caused it to be examined.

      Results are orderable: first by EntitlementStatus value (most to least
      satisfiable); among results with the same status, the one whose
      entitlement expiration date lies further in the future sorts first.
      Results without an entitlement (status no_matching) are equivalent to
      each other.
  """
  def __init__(self, requirement, created_at, entitlement=None):
    self.requirement = requirement
    # Point in time against which the entitlement's expiration is checked.
    self._created_at = created_at
    self.entitlement = entitlement

  def get_status(self):
    """ Returns the EntitlementStatus when comparing the specified Entitlement
        with the corresponding requirement. Expiration is checked before the
        count, so an expired entitlement reports `expired` even if its count
        is also too small.
    """
    if self.entitlement is not None:
      if self.entitlement.expiration.is_expired(self._created_at):
        return EntitlementStatus.expired

      if self.entitlement.count < self.requirement.count:
        return EntitlementStatus.insufficient_count

      return EntitlementStatus.met

    return EntitlementStatus.no_matching

  def is_met(self):
    """ Returns whether this specific EntitlementValidationResult meets all
        of the criteria for being sufficient, including unexpired (or in the
        grace period), and with a sufficient count.
    """
    return self.get_status() == EntitlementStatus.met

  def __lt__(self, rhs):
    own_status = self.get_status()
    rhs_status = rhs.get_status()

    # Different statuses: sort by status value (lower is more satisfiable).
    if own_status != rhs_status:
      return own_status < rhs_status

    # Same status with nothing to compare (both no_matching): treat as equivalent
    # rather than dereferencing `None.expiration`, which raised AttributeError.
    if self.entitlement is None or rhs.entitlement is None:
      return False

    # Same status: the result with an expiration date further in the future sorts
    # first, as it will be more relevant. The results may expire, but so long as
    # this result is valid, so will the entitlement.
    return (self.entitlement.expiration.expiration_date >
            rhs.entitlement.expiration.expiration_date)

  def __repr__(self):
    return str(dict(
      requirement=repr(self.requirement),
      created_at=repr(self._created_at),
      entitlement=repr(self.entitlement),
    ))

  def description(self):
    """ Returns a one-line human-readable summary of the requirement and its status. """
    msg = '%s requires %s: has status %s'
    return msg % (self.requirement.name, self.requirement.count, self.get_status())

  def as_dict(self, for_private=False):
    """ Serializes the requirement and its status. The matched entitlement, if
        any, is included and serialized with the same `for_private` flag.
    """
    def req_view():
      return {
        'name': self.requirement.name,
        'count': self.requirement.count,
      }

    data = {
      'requirement': req_view(),
      'status': str(self.get_status()),
    }

    if self.entitlement is not None:
      data['entitlement'] = self.entitlement.as_dict(for_private=for_private)

    return data


class License(object):
  """ A fully decoded and signature-verified license, which may still be expired.

      `decoded` is the license mapping; its `expirationDate` and `subscriptions`
      keys are consulted when looking up entitlements.
  """
  def __init__(self, decoded):
    self.decoded = decoded

  def validate_entitlement_requirement(self, entitlement_req, check_time):
    """ Evaluates a single requirement against this license as of `check_time`.

        Every subscription entitlement named `entitlement_req.name` becomes an
        EntitlementValidationResult, and the best-ranked one is returned. When
        nothing matches, the returned result carries no entitlement.
    """
    candidates = [EntitlementValidationResult(entitlement_req, check_time, found)
                  for found in self._find_entitlements(entitlement_req.name)]
    if not candidates:
      return EntitlementValidationResult(entitlement_req, check_time)

    # min() yields the first of the best-ranked results, exactly as a stable sort
    # followed by taking element 0 would.
    return min(candidates)

  def _find_entitlements(self, entitlement_name):
    """ Yields an Entitlement for each subscription granting `entitlement_name`.

        Each entitlement lapses at whichever cutoff comes first: its
        subscription's own expiration or the license-wide `expirationDate`.
    """
    license_wide = Expiration(ExpirationType.license_wide,
                              _get_date(self.decoded, 'expirationDate'))

    for subscription in self.decoded.get('subscriptions', {}).values():
      granted = subscription.get('entitlements', {}).get(entitlement_name)
      if granted is None:
        continue

      # Argument order matters on ties: min() keeps the subscription expiration.
      effective = min(self._sub_expiration(subscription), license_wide)
      yield Entitlement(entitlement_name, granted,
                        subscription.get('productName', 'unknown'), effective)

  @staticmethod
  def _sub_expiration(subscription):
    """ Builds the Expiration for one subscription, choosing the end date and the
        grace period from the subscription's trial / duration flags.
    """
    if subscription.get('trialOnly', False):
      # Trial-only subscriptions end at trialEnd rather than serviceEnd.
      return Expiration(ExpirationType.trial_only,
                        _get_date(subscription, 'trialEnd'),
                        TRIAL_GRACE_PERIOD)

    # Every other kind of subscription ends at serviceEnd.
    service_end = _get_date(subscription, 'serviceEnd')

    if subscription.get('inTrial', False):
      kind, grace = ExpirationType.in_trial, TRIAL_GRACE_PERIOD
    elif subscription.get('durationPeriod') == 'yearly':
      kind, grace = ExpirationType.yearly, YEARLY_GRACE_PERIOD
    else:
      # Anything not explicitly yearly is treated as a monthly subscription.
      kind, grace = ExpirationType.monthly, MONTHLY_GRACE_PERIOD

    return Expiration(kind, service_end, grace)

  def validate(self, config):
    """ Returns a list of EntitlementValidationResult objects, one per requirement
        derived from `config`, all evaluated at the current local time.
    """
    requirements = _gen_entitlement_requirements(config)
    check_time = datetime.now()
    return [self.validate_entitlement_requirement(requirement, check_time)
            for requirement in requirements]


# PEM-encoded public key that decode_license verifies license signatures against when
# no explicit key instance is supplied.
_PROD_LICENSE_PUBLIC_KEY_DATA = """
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAuCkRnkuqox3A0djgRnHR
e3U3jHrcbd5iUqdbfO/8E2TMbiByIy3NzUyJrMIzrTjdxTVIZF/ueaHLEtgaofUA
1X73OZlsaGyNVDFA2eGZRgyNrmfLFoxnN2KB+gEJ88nPkHZXY+4ncZBjVMKfHQEv
busC7xpnF7Diy2GxZKDZRnvjL4ZNrocdoeE0GuroWwebtck5Ea7LqzRxCJ5T3UWt
EozttOBQAqCmKxSDdtdw+CsK/uTfl6Yh9xCZUrCeh5taSOHOvU0ne/p3gM+AsjU4
ScjObTKaSUOGen6aYFF5Bd6V/ucxHmcmJlycwNZOKGFpbhLU173/oBJ+okvDbJpN
qwIDAQAB
-----END PUBLIC KEY-----
"""
# Parsed once, at module import time, so a malformed key fails the import itself.
# NOTE(review): a str is passed here; on Python 3 `cryptography` expects bytes, so this
# would need `.encode('ascii')` if the module is ported.
_PROD_LICENSE_PUBLIC_KEY = load_pem_public_key(_PROD_LICENSE_PUBLIC_KEY_DATA,
                                               backend=default_backend())

def decode_license(license_contents, public_key_instance=None):
  """ Verifies and decodes the specified license contents, returning a License.

      `license_contents` is a JWT whose signature is checked against
      `public_key_instance`, or against the production public key when none is
      given. Its `license` claim is parsed as JSON and wrapped in a License.

      Raises:
        LicenseDecodeError: if PyJWT rejects the token for any reason (bad
          signature, disallowed algorithm, expired claims, malformed token), or
          if the `license` claim is not a JSON object.
  """
  license_public_key = public_key_instance or _PROD_LICENSE_PUBLIC_KEY

  # Only accept RSA signature algorithms (the production key is an RSA public key).
  # Without an explicit allowlist the token header selects the algorithm; a header
  # naming e.g. an HMAC algorithm may then fail with a non-token error (such as a
  # TypeError while preparing the key) instead of a clean rejection.
  allowed_algorithms = ['RS256', 'RS384', 'RS512', 'PS256', 'PS384', 'PS512']

  try:
    jwt_data = jwt.decode(license_contents, key=license_public_key,
                          algorithms=allowed_algorithms)
  except jwt.exceptions.InvalidTokenError as ite:
    # InvalidTokenError is the base of DecodeError, ExpiredSignatureError,
    # InvalidAlgorithmError, ...; catching only DecodeError let e.g. an expired token
    # escape, and LicenseValidator only handles IOError / LicenseDecodeError.
    logger.exception('Could not decode license file')
    # `%s` uses str(); the `.message` attribute is deprecated and Python 2 only.
    raise LicenseDecodeError('Could not decode license found: %s' % ite)

  try:
    decoded = json.loads(jwt_data.get('license', '{}'))
  except (TypeError, ValueError) as ve:
    # ValueError: malformed JSON; TypeError: the claim is not a string at all.
    logger.exception('Could not decode license file')
    raise LicenseDecodeError('Could not decode license found: %s' % ve)

  # License looks keys up with .get(); anything other than a JSON object would only
  # fail later, outside this function's error contract.
  if not isinstance(decoded, dict):
    raise LicenseDecodeError('Could not decode license found: license is not a JSON object')

  return License(decoded)


# Delay before LicenseValidator re-checks a license that was sufficient at the last check.
LICENSE_VALIDATION_INTERVAL = 3600 # seconds
# Shorter delay used after a check found the license insufficient or unreadable.
LICENSE_VALIDATION_EXPIRED_INTERVAL = 60 # seconds


# A required entitlement: its name and the minimum count the license must grant.
# The typename matches the module attribute so instances pickle correctly and their
# repr shows the public name (it was previously 'EntitlementRequirements').
EntitlementRequirement = namedtuple('EntitlementRequirement', ['name', 'count'])


def _gen_entitlement_requirements(config_obj):
  """ Builds the list of EntitlementRequirements implied by a config mapping.

      One QUAY_ENTITLEMENT is always required, plus a QUAY_DEPLOYMENTS_ENTITLEMENT
      count equal to the number of entries in DISTRIBUTED_STORAGE_CONFIG.
  """
  # `or ()` treats a key that is present but null (e.g. an empty YAML value) like a
  # missing key, instead of raising TypeError from len(None).
  storage_config = config_obj.get('DISTRIBUTED_STORAGE_CONFIG') or ()
  config_regions = len(storage_config)
  return [
    EntitlementRequirement(QUAY_ENTITLEMENT, 1),
    EntitlementRequirement(QUAY_DEPLOYMENTS_ENTITLEMENT, config_regions),
  ]


class LicenseValidator(Thread):
  """
  LicenseValidator is a thread that asynchronously reloads and validates license files.

  This thread is meant to be run before registry gunicorn workers fork and uses shared memory as a
  synchronization primitive.
  """
  def __init__(self, config_provider, *args, **kwargs):
    config = config_provider.get_config() or {}

    self._config_provider = config_provider
    # Requirements are derived once, from the config present at construction time.
    self._entitlement_requirements = _gen_entitlement_requirements(config)

    # multiprocessing.Value does not ensure consistent write-after-reads, but we don't need that.
    # Both flags start pessimistic (True) until the first validation completes.
    self._license_is_insufficient = multiprocessing.Value(c_bool, True)
    self._license_expiring_soon = multiprocessing.Value(c_bool, True)

    super(LicenseValidator, self).__init__(*args, **kwargs)
    self.daemon = True

  @property
  def expiring_soon(self):
    """ Returns whether the license will be expiring soon (a week from now). """
    return self._license_expiring_soon.value

  @property
  def insufficient(self):
    """ Returns whether the last validation found the license insufficient. """
    return self._license_is_insufficient.value

  def compute_license_sufficiency(self):
    """ Check whether all of our requirements are met, and set the status of
        the result of the check, which will be used to disable the software.
        Returns True if any requirements are not met, and False if all are met.

        An unreadable or undecodable license (IOError / LicenseDecodeError) is
        treated as insufficient but not as "expiring soon". Other exceptions
        propagate to the caller.
    """
    try:
      current_license = self._config_provider.get_license()
      now = datetime.now()
      soon = now + LICENSE_SOON_DELTA
      any_invalid = not all(current_license.validate_entitlement_requirement(req, now).is_met()
                            for req in self._entitlement_requirements)
      soon_invalid = not all(current_license.validate_entitlement_requirement(req, soon).is_met()
                             for req in self._entitlement_requirements)
      logger.debug('updating license license_is_insufficient to %s', any_invalid)
      logger.debug('updating license license_expiring_soon to %s', soon_invalid)
    except (IOError, LicenseDecodeError):
      logger.exception('failed to validate license')
      any_invalid = True
      soon_invalid = False

    self._license_is_insufficient.value = any_invalid
    self._license_expiring_soon.value = soon_invalid
    return any_invalid

  def run(self):
    """ Re-validates the license forever, sleeping between checks.

        Invalid licenses are re-checked every LICENSE_VALIDATION_EXPIRED_INTERVAL
        seconds and valid ones every LICENSE_VALIDATION_INTERVAL seconds. An
        unexpected error during a check is logged and handled like an unreadable
        license instead of terminating the thread. A dead thread would otherwise
        freeze the shared flags at their last values for the life of the process.
    """
    logger.debug('Starting license validation thread')
    while True:
      try:
        invalid = self.compute_license_sufficiency()
      except Exception:  # pylint: disable=broad-except
        # Top-level boundary of a daemon thread: log and fail closed, mirroring the
        # IOError / LicenseDecodeError handling in compute_license_sufficiency.
        logger.exception('unexpected error while validating license')
        self._license_is_insufficient.value = True
        self._license_expiring_soon.value = False
        invalid = True

      sleep_time = LICENSE_VALIDATION_EXPIRED_INTERVAL if invalid else LICENSE_VALIDATION_INTERVAL
      logger.debug('waiting %d seconds before retrying to validate license', sleep_time)
      time.sleep(sleep_time)

  def enforce_license_before_request(self, blueprint, response_func=None):
    """
    Adds a pre-check to a Flask blueprint such that if this validator determines the
    license has become invalid, the client will receive a HTTP 402 response (or whatever
    `response_func` returns, when one is provided).
    """
    if response_func is None:
      def _response_func():
        return make_response('License is insufficient.', 402)
      response_func = _response_func

    def _enforce_license():
      # Returning None lets the request proceed; returning a response short-circuits it.
      if self.insufficient:
        logger.debug('blocked interaction due to insufficient license')
        return response_func()
    blueprint.before_request(_enforce_license)