from cachetools import lru_cache
from data import model
from util.expiresdict import ExpiresDict, ExpiresEntry
from util.security import jwtutil


class InstanceKeys(object):
  """ InstanceKeys defines a helper class for interacting with the Quay instance service keys
      used for JWT signing of registry tokens as well as requests from Quay to other services
      such as Clair. Each container will have a single registered instance key.

      Caches maintained per instance:
        - `instance_keys`: an ExpiresDict rebuilt via `_load_instance_keys` from the service
          keys registered for `service_name`.
        - `public_keys`: kid -> public key converted from that key's JWK.
        - the local key ID and private key, read from the files named in the app config on
          first access.
  """
  def __init__(self, app):
    self.app = app

    # Per-instance caches for the local key files; filled lazily by the properties below.
    self._local_key_id = None
    self._local_private_key = None

    self._reset_key_caches()

  def clear_cache(self):
    """ Clears the cache of instance keys (and the derived public key cache). The cached local
        key ID and private key are left untouched.
    """
    self._reset_key_caches()

  def _reset_key_caches(self):
    """ Replaces the instance key and public key caches with fresh ones. """
    # `public_keys` is assigned first because `_load_instance_keys` reads and prunes it; this
    # keeps the loader safe even if ExpiresDict were to invoke it during construction.
    self.public_keys = {}
    self.instance_keys = ExpiresDict(self._load_instance_keys)

  def _load_instance_keys(self):
    """ Loads the service keys registered for this service and prunes stale public keys.

        Returns a dict mapping each key's kid to an ExpiresEntry wrapping the key and its
        expiration date; this is the rebuild function for the `instance_keys` ExpiresDict.
    """
    # Load all the instance keys.
    keys = {key.kid: ExpiresEntry(key, key.expiration_date)
            for key in model.service_keys.list_service_keys(self.service_name)}

    # Remove any expired or deleted keys from the public keys cache. Iterate over a snapshot
    # of the kids because the dict is mutated inside the loop.
    for kid in list(self.public_keys):
      if kid not in keys:
        self.public_keys.pop(kid, None)

    return keys

  @property
  def service_name(self):
    """ Returns the name of the instance key's service (i.e. 'quay'). """
    return self.app.config['INSTANCE_SERVICE_KEY_SERVICE']

  @property
  def service_key_expiration(self):
    """ Returns the defined expiration for instance service keys, in minutes. """
    return self.app.config.get('INSTANCE_SERVICE_KEY_EXPIRATION', 120)

  @property
  def local_key_id(self):
    """ Returns the ID of the local instance service key, read from the file at the configured
        INSTANCE_SERVICE_KEY_KID_LOCATION and cached on this instance after the first read.
    """
    return self._cached_file_contents('_local_key_id', 'INSTANCE_SERVICE_KEY_KID_LOCATION')

  @property
  def local_private_key(self):
    """ Returns the private key of the local instance service key, read from the file at the
        configured INSTANCE_SERVICE_KEY_LOCATION and cached on this instance after the first read.
    """
    return self._cached_file_contents('_local_private_key', 'INSTANCE_SERVICE_KEY_LOCATION')

  def _cached_file_contents(self, attr_name, config_key):
    """ Returns the contents of the file at `self.app.config[config_key]`, storing them on this
        instance under `attr_name` so later accesses do not re-read the file.
    """
    # Cached per instance rather than via lru_cache on the method: a method-level lru_cache is
    # keyed on (and keeps alive) `self`, and with maxsize=1 separate instances evict each other.
    contents = getattr(self, attr_name, None)
    if contents is None:
      contents = _load_file_contents(self.app.config[config_key])
      setattr(self, attr_name, contents)
    return contents

  def get_service_key_public_key(self, kid):
    """ Returns the public key associated with the given instance service key or None if none.

        Any cached public key for a kid that is not present in `instance_keys` is evicted.
    """
    # Note: We do the lookup via instance_keys *first* to ensure that if a key has expired, we
    # don't use the entry in the public key cache.
    service_key = self.instance_keys.get(kid)
    if service_key is None:
      # Remove the kid from the cache just to be sure.
      self.public_keys.pop(kid, None)
      return None

    public_key = self.public_keys.get(kid)
    if public_key is not None:
      return public_key

    # Convert the JWK into a public key and cache it (since the conversion can take > 200ms).
    public_key = jwtutil.jwk_dict_to_public_key(service_key.jwk)
    self.public_keys[kid] = public_key
    return public_key


def _load_file_contents(path):
  """ Reads the file located at `path` (in text mode) and returns its entire contents. """
  with open(path) as file_handle:
    contents = file_handle.read()
  return contents