import logging

from uuid import uuid4

from peewee import IntegrityError, JOIN_LEFT_OUTER, fn
from data.model import (image, db_transaction, DataModelException, _basequery,
                        InvalidManifestException, TagAlreadyCreatedException, StaleTagException)
from data.database import (RepositoryTag, Repository, Image, ImageStorage, Namespace, TagManifest,
                           RepositoryNotification, Label, TagManifestLabel, get_epoch_timestamp,
                           db_for_update)


# Module-level logger; _delete_tags uses it to report how many tags and manifests it removed.
logger = logging.getLogger(__name__)


def get_max_id_for_sec_scan():
  """ Returns the largest RepositoryTag id in the database, for bounding security scanning
      passes over the tag table.
  """
  max_id_query = RepositoryTag.select(fn.Max(RepositoryTag.id))
  return max_id_query.scalar()


def get_min_id_for_sec_scan(version):
  """ Returns the smallest id among alive tags whose image's security_indexed_engine is lower
      than `version`, i.e. the lower bound for a security scanning pass.
  """
  not_yet_indexed = (RepositoryTag
                     .select(fn.Min(RepositoryTag.id))
                     .join(Image)
                     .where(Image.security_indexed_engine < version))
  return _tag_alive(not_yet_indexed).scalar()


def get_tag_pk_field():
  """ Returns the primary key field of the RepositoryTag DB model. """
  return RepositoryTag.id


def get_tags_images_eligible_for_scan(clair_version):
  """ Returns a query over alive, non-hidden tags whose image has a security_indexed_engine
      lower than `clair_version`.

      Each row selects the tag, its image and the image's storage, plus the image's parent image
      and the parent's storage. The parent joins are LEFT OUTER, so those may be absent.
  """
  parent_image = Image.alias()
  parent_storage = ImageStorage.alias()

  query = RepositoryTag.select(Image, ImageStorage, parent_image, parent_storage, RepositoryTag)
  query = query.join(Image, on=(RepositoryTag.image == Image.id))
  query = query.join(ImageStorage, on=(Image.storage == ImageStorage.id))

  # Go back to Image so the parent joins hang off the tag's image, not off its storage.
  query = query.switch(Image)
  query = query.join(parent_image, JOIN_LEFT_OUTER, on=(Image.parent == parent_image.id))
  query = query.join(parent_storage, JOIN_LEFT_OUTER,
                     on=(parent_storage.id == parent_image.storage))

  query = query.where(RepositoryTag.hidden == False)
  query = query.where(Image.security_indexed_engine < clair_version)
  return _tag_alive(query)


def _tag_alive(query, now_ts=None):
  """ Restricts `query` to RepositoryTag rows alive at `now_ts` (defaults to the current epoch
      timestamp). A tag is alive if it has no lifetime_end_ts, or its end is after `now_ts`.
  """
  cutoff_ts = get_epoch_timestamp() if now_ts is None else now_ts
  never_ends = RepositoryTag.lifetime_end_ts >> None
  ends_later = RepositoryTag.lifetime_end_ts > cutoff_ts
  return query.where(never_ends | ends_later)


def filter_has_repository_event(query, event):
  """ Narrows `query` to rows whose Repository (joined from the query's current join context)
      has a RepositoryNotification registered for `event`.
  """
  with_notifications = query.join(Repository).join(RepositoryNotification)
  return with_notifications.where(RepositoryNotification.event == event)


def filter_tags_have_repository_event(query, event):
  """ Narrows a tag query to tags whose repository has a notification for `event`, joins the
      storage of each tag's image, and orders the results newest lifetime_start_ts first.
  """
  with_event = filter_has_repository_event(query, event)
  with_storage = with_event.switch(Image).join(ImageStorage)
  newest_first = RepositoryTag.lifetime_start_ts.desc()
  return with_storage.switch(RepositoryTag).order_by(newest_first)


# Maximum number of per-image subqueries combined (via _basequery.reduce_as_tree) into a single
# tag lookup in get_matching_tags_for_images.
_MAX_SUB_QUERIES = 100
# Maximum number of (docker_image_id, storage_uuid) pairs looked up per image query batch in
# get_matching_tags_for_images.
_MAX_IMAGE_LOOKUP_COUNT = 500

def get_matching_tags_for_images(image_pairs, filter_images=None, filter_tags=None,
                                 selections=None):
  """ Returns all alive, non-hidden tags that contain any of the given images.

      A tag "contains" an image if the tag's image is that image or one of its descendants (an
      image whose ancestors path starts with the image's own path).

      Args:
        image_pairs: iterable of (docker_image_id, storage_uuid) pairs. Any iterable (list,
                     tuple, set, generator) is accepted; it is materialized once up front.
        filter_images: optional function applied to each batched image lookup query.
        filter_tags: optional function applied to each tag lookup query.
        selections: fields/models to select for each tag.

      Returns:
        The matching tags, deduplicated by tag id (the values of a dict keyed by tag id), or an
        empty list if no requested image was found.
  """
  # Materialize once. The batching below slices the pairs, which fails for sets and
  # generators. A generator is also always truthy, so the emptiness check would not work.
  image_pairs = list(image_pairs or [])
  if not image_pairs:
    return []

  image_pairs_set = set(image_pairs)

  # Find all possible matching image+storages, in batches to bound the size of each query.
  images = []
  for start in range(0, len(image_pairs), _MAX_IMAGE_LOOKUP_COUNT):
    image_pairs_slice = image_pairs[start:start + _MAX_IMAGE_LOOKUP_COUNT]

    ids = [pair[0] for pair in image_pairs_slice]
    uuids = [pair[1] for pair in image_pairs_slice]

    images_query = (Image
                    .select(Image.id, Image.docker_image_id, Image.ancestors, ImageStorage.uuid)
                    .join(ImageStorage)
                    .where(Image.docker_image_id << ids, ImageStorage.uuid << uuids)
                    .switch(Image))

    if filter_images is not None:
      images_query = filter_images(images_query)

    images.extend(images_query)

  # Filter down to those images actually in the pairs set and build the set of queries to run.
  individual_image_queries = []

  for img in images:
    # The lookup above matches docker image IDs and storage UUIDs independently (ID in the
    # batch's IDs AND UUID in the batch's UUIDs), rather than as exact ID+UUID pairs, for
    # efficiency. So it can return an image whose ID and UUID come from two different requested
    # pairs. Keep only exact pair matches, and process each pair at most once.
    pair = (img.docker_image_id, img.storage.uuid)
    if pair not in image_pairs_set:
      continue

    # Remove the pair so we don't try it again.
    image_pairs_set.remove(pair)

    # Match the image itself, plus any image whose ancestors path begins with this image's path.
    ancestors_str = '%s%s/%%' % (img.ancestors, img.id)
    query = (Image
             .select(Image.id)
             .where((Image.id == img.id) | (Image.ancestors ** ancestors_str)))

    individual_image_queries.append(query)

  if not individual_image_queries:
    return []

  # Shard based on the max subquery count. This is used to prevent going over the DB's max query
  # size, as well as to prevent the DB from locking up on a massive query.
  sharded_queries = [
    _basequery.reduce_as_tree(individual_image_queries[start:start + _MAX_SUB_QUERIES])
    for start in range(0, len(individual_image_queries), _MAX_SUB_QUERIES)
  ]

  # Collect the tags found for each query, keyed by tag id so that a tag matched by several
  # shards is only returned once.
  tags = {}
  for query in sharded_queries:
    tag_query = (_tag_alive(RepositoryTag
                            .select(*(selections or []))
                            .distinct()
                            .join(Image)
                            .where(RepositoryTag.hidden == False)
                            .where(Image.id << query)
                            .switch(RepositoryTag)))

    if filter_tags is not None:
      tag_query = filter_tags(tag_query)

    for tag in tag_query:
      tags[tag.id] = tag

  return tags.values()


def get_matching_tags(docker_image_id, storage_uuid, *args):
  """ Returns a query over distinct, alive, non-hidden tags (selecting `args`) whose image is
      the image identified by `docker_image_id` + `storage_uuid`, or one whose ancestors path
      starts with that image's path. If the image does not exist, the returned query is empty.
  """
  found_image = image.get_image_with_storage(docker_image_id, storage_uuid)
  if found_image is None:
    # No ids are negative, so this query can never return a row.
    return RepositoryTag.select().where(RepositoryTag.id < 0)

  descendant_pattern = '%s%s/%%' % (found_image.ancestors, found_image.id)
  is_image_or_descendant = ((Image.id == found_image.id) |
                            (Image.ancestors ** descendant_pattern))

  tags_query = (RepositoryTag
                .select(*args)
                .distinct()
                .join(Image)
                .join(ImageStorage)
                .where(RepositoryTag.hidden == False)
                .where(is_image_or_descendant))
  return _tag_alive(tags_query)


def get_tags_for_image(image_id, *args):
  """ Returns a query over distinct, alive, non-hidden tags (selecting `args`) that point
      directly at the image `image_id`.
  """
  direct_tags = (RepositoryTag
                 .select(*args)
                 .distinct()
                 .where(RepositoryTag.image == image_id, RepositoryTag.hidden == False))
  return _tag_alive(direct_tags)


def get_tag_manifest_digests(tags):
  """ Returns a dict mapping tag id -> manifest digest for each of `tags` that has a TagManifest.
      Tags without a manifest are absent from the result; an empty input yields an empty dict.
  """
  if not tags:
    return {}

  tag_ids = [tag.id for tag in tags]
  manifests = (TagManifest
               .select(TagManifest.tag, TagManifest.digest)
               .where(TagManifest.tag << tag_ids))

  digest_by_tag_id = {}
  for manifest in manifests:
    digest_by_tag_id[manifest.tag_id] = manifest.digest
  return digest_by_tag_id


def list_active_repo_tags(repo):
  """ Returns a query over all alive, non-hidden tags in `repo`, joined to their images and
      LEFT OUTER joined to their manifest, so the manifest digest is selected when present.
  """
  tags_with_images = (RepositoryTag
                      .select(RepositoryTag, Image, TagManifest.digest)
                      .join(Image)
                      .where(RepositoryTag.repository == repo, RepositoryTag.hidden == False))
  with_manifests = tags_with_images.switch(RepositoryTag).join(TagManifest, JOIN_LEFT_OUTER)
  return _tag_alive(with_manifests)


def list_repository_tags(namespace_name, repository_name, include_hidden=False,
                         include_storage=False):
  """ Returns a query over the alive tags of repository `namespace_name`/`repository_name`, each
      joined to its image.

      Hidden tags are excluded unless `include_hidden` is True. When `include_storage` is True,
      the image storage is joined and selected as well.
  """
  selected = [RepositoryTag, Image]
  if include_storage:
    selected.append(ImageStorage)

  in_repository = (RepositoryTag
                   .select(*selected)
                   .join(Repository)
                   .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                   .switch(RepositoryTag)
                   .join(Image)
                   .where(Repository.name == repository_name,
                          Namespace.username == namespace_name))
  query = _tag_alive(in_repository)

  if not include_hidden:
    query = query.where(RepositoryTag.hidden == False)

  if include_storage:
    query = query.switch(Image).join(ImageStorage)

  return query


def create_or_update_tag(namespace_name, repository_name, tag_name, tag_docker_image_id,
                         reversion=False):
  """ Points tag `tag_name` in repository `namespace_name`/`repository_name` at the image whose
      docker id is `tag_docker_image_id` in that repository.

      Within a single transaction, any currently-alive tag with that name is ended (its
      lifetime_end_ts is set to now), and a new RepositoryTag row starting now is created. The
      new row records `reversion`. Returns the newly created RepositoryTag.

      Raises:
        DataModelException: if the repository or the image does not exist.
        StaleTagException: if an IntegrityError occurs while locating/ending the alive tag.
        TagAlreadyCreatedException: if creating the new tag row raises an IntegrityError.
  """
  try:
    repo = _basequery.get_existing_repository(namespace_name, repository_name)
  except Repository.DoesNotExist:
    raise DataModelException('Invalid repository %s/%s' % (namespace_name, repository_name))

  # One timestamp is used both to end the old tag and to start the new one.
  now_ts = get_epoch_timestamp()

  with db_transaction():
    try:
      # Find the tag with this name that is alive at now_ts and end it.
      # NOTE(review): db_for_update presumably locks the row (SELECT ... FOR UPDATE) - confirm.
      tag = db_for_update(_tag_alive(RepositoryTag
                                     .select()
                                     .where(RepositoryTag.repository == repo,
                                            RepositoryTag.name == tag_name), now_ts)).get()
      tag.lifetime_end_ts = now_ts
      tag.save()
    except RepositoryTag.DoesNotExist:
      # No alive tag with this name yet: nothing to end.
      pass
    except IntegrityError:
      msg = 'Tag with name %s was stale when we tried to update it; Please retry the push'
      raise StaleTagException(msg % tag_name)

    try:
      image_obj = Image.get(Image.docker_image_id == tag_docker_image_id, Image.repository == repo)
    except Image.DoesNotExist:
      raise DataModelException('Invalid image with id: %s' % tag_docker_image_id)

    try:
      return RepositoryTag.create(repository=repo, image=image_obj, name=tag_name,
                                  lifetime_start_ts=now_ts, reversion=reversion)
    except IntegrityError:
      # Presumably a uniqueness constraint covering the tag name/start time was violated
      # (another tag created with the same name at the same timestamp) - TODO confirm.
      msg = 'Tag with name %s and lifetime start %s under repository %s/%s already exists'
      raise TagAlreadyCreatedException(msg % (tag_name, now_ts, namespace_name, repository_name))


def create_temporary_hidden_tag(repo, image_obj, expiration_s):
  """ Creates a hidden tag on `repo` pointing at `image_obj`, alive from now for `expiration_s`
      seconds and named with a random UUID4. Hidden tags do not appear in the UI or CLI.
      Returns the generated tag name.
  """
  start_ts = get_epoch_timestamp()
  temporary_name = str(uuid4())
  RepositoryTag.create(repository=repo, image=image_obj, name=temporary_name,
                       lifetime_start_ts=start_ts, lifetime_end_ts=start_ts + expiration_s,
                       hidden=True)
  return temporary_name


def delete_tag(namespace_name, repository_name, tag_name):
  """ Ends the alive tag `tag_name` in repository `namespace_name`/`repository_name` by setting
      its lifetime_end_ts to now, inside a transaction. The row is not removed. Returns the
      updated tag.

      Raises DataModelException if no alive tag with that name exists.
  """
  now_ts = get_epoch_timestamp()
  with db_transaction():
    alive_tag_query = _tag_alive(RepositoryTag
                                 .select(RepositoryTag, Repository)
                                 .join(Repository)
                                 .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                                 .where(Repository.name == repository_name,
                                        Namespace.username == namespace_name,
                                        RepositoryTag.name == tag_name), now_ts)
    try:
      found = db_for_update(alive_tag_query).get()
    except RepositoryTag.DoesNotExist:
      msg = ('Invalid repository tag \'%s\' on repository \'%s/%s\'' %
             (tag_name, namespace_name, repository_name))
      raise DataModelException(msg)

    found.lifetime_end_ts = now_ts
    found.save()
    return found


def garbage_collect_tags(repo):
  """ Deletes the tags in `repo` whose lifetime ended at least the owning namespace's
      removed_tag_expiration_s seconds ago, and returns a set of image ids which *may* now
      be orphaned.
  """
  def only_expired(base_query):
    # A tag is past its window once lifetime_end_ts <= now - namespace expiration seconds.
    cutoff = get_epoch_timestamp() - Namespace.removed_tag_expiration_s
    has_ended = ~(RepositoryTag.lifetime_end_ts >> None)
    with_namespace = (base_query
                      .switch(RepositoryTag)
                      .join(Repository)
                      .join(Namespace, on=(Repository.namespace_user == Namespace.id)))
    return with_namespace.where(has_ended, RepositoryTag.lifetime_end_ts <= cutoff)

  return _delete_tags(repo, only_expired)

def purge_all_tags(repo):
  """ Removes all tags from the repository, together with their manifests and manifest labels,
      and returns a set of the image ids which *may* now be orphaned.
  """
  return _delete_tags(repo)

def _delete_tags(repo, query_modifier=None):
  """ Deletes tags in `repo`, along with their TagManifests and the labels attached to those
      manifests.

      Args:
        repo: the repository whose tags are deleted.
        query_modifier: optional function that takes and returns the tag selection query, used
                        to narrow which tags are deleted. If None, every tag in `repo` is deleted.

      Returns:
        A set of the image ids which may now be orphaned: each deleted tag's image id plus all
        of that image's ancestor ids. An empty set if no tags matched.
  """
  tags_to_delete_q = (RepositoryTag
                      .select(RepositoryTag.id, Image.ancestors, Image.id)
                      .join(Image)
                      .where(RepositoryTag.repository == repo))

  if query_modifier is not None:
    tags_to_delete_q = query_modifier(tags_to_delete_q)

  tags_to_delete = list(tags_to_delete_q)

  if not tags_to_delete:
    return set()

  with db_transaction():
    manifests_to_delete = list(TagManifest
                               .select(TagManifest.id)
                               .join(RepositoryTag)
                               .where(RepositoryTag.id << tags_to_delete))

    num_deleted_manifests = 0
    if manifests_to_delete:
      # Find the set of IDs for all the labels to delete.
      manifest_labels_query = (TagManifestLabel
                               .select()
                               .where(TagManifestLabel.repository == repo,
                                      TagManifestLabel.annotated << manifests_to_delete))

      label_ids = [manifest_label.label_id for manifest_label in manifest_labels_query]
      if label_ids:
        # Delete the mapping entries first, then the labels they point at.
        (TagManifestLabel
         .delete()
         .where(TagManifestLabel.repository == repo,
                TagManifestLabel.annotated << manifests_to_delete)
         .execute())

        Label.delete().where(Label.id << label_ids).execute()

      num_deleted_manifests = (TagManifest
                               .delete()
                               .where(TagManifest.id << manifests_to_delete)
                               .execute())

    num_deleted_tags = (RepositoryTag
                        .delete()
                        .where(RepositoryTag.id << tags_to_delete)
                        .execute())

    logger.debug('Removed %s tags with %s manifests', num_deleted_tags, num_deleted_manifests)

    # Every deleted tag's image, and all of its ancestors, may have lost its last reference.
    # Build the union directly; the builtin `reduce` used previously does not exist in Python 3.
    possibly_orphaned = {tag.image.id for tag in tags_to_delete}
    for tag in tags_to_delete:
      possibly_orphaned.update(tag.image.ancestor_id_list())
    return possibly_orphaned


def _get_repo_tag_image(tag_name, include_storage, modifier):
  """ Returns the image referenced by an alive tag named `tag_name`, after `modifier` (a
      function taking and returning the query) has scoped the lookup. When `include_storage`
      is True the image's storage is selected as well.

      Raises DataModelException if no such image is found.
  """
  if include_storage:
    base_query = (Image
                  .select(Image, ImageStorage)
                  .join(ImageStorage)
                  .switch(Image)
                  .join(RepositoryTag))
  else:
    base_query = Image.select().join(RepositoryTag)

  scoped = modifier(base_query.where(RepositoryTag.name == tag_name))
  found_images = list(_tag_alive(scoped))
  if found_images:
    return found_images[0]

  raise DataModelException('Unable to find image for tag.')


def get_repo_tag_image(repo, tag_name, include_storage=False):
  """ Returns the image the alive tag `tag_name` in `repo` points to, optionally with its
      storage. Raises DataModelException if there is no such tag image.
  """
  return _get_repo_tag_image(tag_name, include_storage,
                             lambda query: query.where(RepositoryTag.repository == repo))


def get_tag_image(namespace_name, repository_name, tag_name, include_storage=False):
  """ Returns the image the alive tag `tag_name` in repository `namespace_name`/`repository_name`
      points to. If `include_storage` is True, the image's storage is selected as well.

      Raises DataModelException if no such image is found.
  """
  def modifier(query):
    # Join Namespace through Repository.namespace_user explicitly, as every other
    # Repository -> Namespace join in this module does, instead of relying on peewee to infer
    # the join condition.
    return (query
            .switch(RepositoryTag)
            .join(Repository)
            .join(Namespace, on=(Repository.namespace_user == Namespace.id))
            .where(Namespace.username == namespace_name, Repository.name == repository_name))

  return _get_repo_tag_image(tag_name, include_storage, modifier)


def list_repository_tag_history(repo_obj, page=1, size=100, specific_tag=None):
  """ Returns one page of non-hidden tag history (alive and ended tags) for `repo_obj`.

      Tags are ordered newest lifetime_start_ts first, then by name. `page` is 1-based and holds
      `size` entries. When `specific_tag` is given, only tags with that name are listed.

      Returns a tuple (tags, manifest_map, has_more):
        tags: at most `size` tags, each joined to its image.
        manifest_map: tag id -> manifest digest (see get_tag_manifest_digests).
        has_more: whether another page exists.
  """
  conditions = [RepositoryTag.repository == repo_obj, RepositoryTag.hidden == False]
  if specific_tag:
    conditions.append(RepositoryTag.name == specific_tag)

  # Fetch one row beyond the page size to learn whether a further page exists.
  page_query = (RepositoryTag
                .select(RepositoryTag, Image)
                .join(Image)
                .switch(RepositoryTag)
                .where(*conditions)
                .order_by(RepositoryTag.lifetime_start_ts.desc(), RepositoryTag.name)
                .limit(size + 1)
                .offset(size * (page - 1)))

  fetched = list(page_query)
  if not fetched:
    return [], {}, False

  digests = get_tag_manifest_digests(fetched)
  return fetched[0:size], digests, len(fetched) > size


def restore_tag_to_manifest(repo_obj, tag_name, manifest_digest):
  """ Restores tag `tag_name` in `repo_obj` to the manifest with `manifest_digest`.

      The digest must previously have been stored for a tag of that name in this repository.
      The tag is re-pointed, as a reversion, via store_tag_manifest.

      Returns the image the tag pointed at before the restore, or None if it had no alive tag
      image. Raises DataModelException if the digest is unknown for this tag.
  """
  with db_transaction():
    # The manifest must already exist under this repository for this tag name.
    prior_manifest_query = (TagManifest
                            .select(TagManifest, RepositoryTag, Image)
                            .join(RepositoryTag)
                            .join(Image)
                            .where(RepositoryTag.repository == repo_obj)
                            .where(RepositoryTag.name == tag_name)
                            .where(TagManifest.digest == manifest_digest))
    try:
      manifest = prior_manifest_query.get()
    except TagManifest.DoesNotExist:
      raise DataModelException('Cannot restore to unknown or invalid digest')

    try:
      previous_image = get_repo_tag_image(repo_obj, tag_name)
    except DataModelException:
      previous_image = None

    target_docker_image_id = manifest.tag.image.docker_image_id
    store_tag_manifest(repo_obj.namespace_user.username, repo_obj.name, tag_name,
                       target_docker_image_id, manifest_digest, manifest.json_data,
                       reversion=True)
    return previous_image


def restore_tag_to_image(repo_obj, tag_name, docker_image_id):
  """ Restores tag `tag_name` in `repo_obj` to the image with `docker_image_id`.

      A tag of that name in this repository must previously have pointed at that image. The
      tag is re-pointed, as a reversion, via create_or_update_tag.

      Returns the image the tag pointed at before the restore, or None if it had no alive tag
      image. Raises DataModelException if the image was never used for this tag.
  """
  with db_transaction():
    # The image must already have been tagged with this name in this repository.
    prior_tag_query = (RepositoryTag
                       .select()
                       .join(Image)
                       .where(RepositoryTag.repository == repo_obj)
                       .where(RepositoryTag.name == tag_name)
                       .where(Image.docker_image_id == docker_image_id))
    try:
      prior_tag_query.get()
    except RepositoryTag.DoesNotExist:
      raise DataModelException('Cannot restore to unknown or invalid image')

    try:
      previous_image = get_repo_tag_image(repo_obj, tag_name)
    except DataModelException:
      previous_image = None

    create_or_update_tag(repo_obj.namespace_user.username, repo_obj.name, tag_name,
                         docker_image_id, reversion=True)
    return previous_image


def store_tag_manifest(namespace, repo_name, tag_name, docker_image_id, manifest_digest,
                       manifest_data, reversion=False):
  """ Creates or updates tag `tag_name` to point at `docker_image_id`, and associates the
      manifest `manifest_digest` with the resulting tag, all within one transaction.

      If a TagManifest with that digest already exists, it is re-pointed at the new tag;
      otherwise one is created with `manifest_data`. Returns a tuple (TagManifest, created).
  """
  with db_transaction():
    tag = create_or_update_tag(namespace, repo_name, tag_name, docker_image_id, reversion=reversion)

    try:
      manifest = TagManifest.get(digest=manifest_digest)
    except TagManifest.DoesNotExist:
      return TagManifest.create(tag=tag, digest=manifest_digest, json_data=manifest_data), True

    # Known digest: move the existing manifest over to the freshly created tag row.
    manifest.tag = tag
    manifest.save()
    return manifest, False


def get_active_tag(namespace, repo_name, tag_name):
  """ Returns the alive tag `tag_name` in repository `namespace`/`repo_name`.
      Raises RepositoryTag.DoesNotExist if there is none.
  """
  named_tag = (RepositoryTag
               .select()
               .join(Repository)
               .join(Namespace, on=(Repository.namespace_user == Namespace.id))
               .where(RepositoryTag.name == tag_name, Repository.name == repo_name,
                      Namespace.username == namespace))
  return _tag_alive(named_tag).get()


def associate_generated_tag_manifest(namespace, repo_name, tag_name, manifest_digest,
                                     manifest_data):
  """ Creates a TagManifest with `manifest_digest`/`manifest_data` for the alive tag `tag_name`
      in `namespace`/`repo_name` and returns it. Raises RepositoryTag.DoesNotExist if the tag
      is not alive.
  """
  return TagManifest.create(tag=get_active_tag(namespace, repo_name, tag_name),
                            digest=manifest_digest, json_data=manifest_data)


def load_tag_manifest(namespace, repo_name, tag_name):
  """ Returns the TagManifest of the alive tag `tag_name` in `namespace`/`repo_name`.
      Raises InvalidManifestException if there is none.
  """
  repo_manifests = _load_repo_manifests(namespace, repo_name)
  try:
    return repo_manifests.where(RepositoryTag.name == tag_name).get()
  except TagManifest.DoesNotExist:
    raise InvalidManifestException(
      'Manifest not found for tag {0} in repo {1}/{2}'.format(tag_name, namespace, repo_name))


def delete_manifest_by_digest(namespace, repo_name, digest):
  """ Ends (via delete_tag) every alive tag in `namespace`/`repo_name` whose manifest has
      `digest`, and returns the list of those tags.
  """
  matching_manifests = list(_load_repo_manifests(namespace, repo_name)
                            .where(TagManifest.digest == digest))
  affected_tags = [tag_manifest.tag for tag_manifest in matching_manifests]

  for tag in affected_tags:
    delete_tag(namespace, repo_name, tag.name)

  return affected_tags


def load_manifest_by_digest(namespace, repo_name, digest):
  """ Returns the TagManifest with `digest` whose tag is alive in `namespace`/`repo_name`.
      Raises InvalidManifestException if there is none.
  """
  repo_manifests = _load_repo_manifests(namespace, repo_name)
  try:
    return repo_manifests.where(TagManifest.digest == digest).get()
  except TagManifest.DoesNotExist:
    raise InvalidManifestException(
      'Manifest not found with digest {0} in repo {1}/{2}'.format(digest, namespace, repo_name))


def _load_repo_manifests(namespace, repo_name):
  """ Returns a query over TagManifests (selected with their tags) whose tag is alive and whose
      tag's image belongs to repository `namespace`/`repo_name`.
  """
  manifests_with_tags = (TagManifest
                         .select(TagManifest, RepositoryTag)
                         .join(RepositoryTag)
                         .join(Image)
                         .join(Repository)
                         .join(Namespace, on=(Namespace.id == Repository.namespace_user))
                         .where(Repository.name == repo_name, Namespace.username == namespace))
  return _tag_alive(manifests_with_tags)