import logging

from peewee import JOIN_LEFT_OUTER, fn, SQL

from data.model import config, db_transaction, InvalidImageException
from data.database import (ImageStorage, Image, DerivedImageStorage, ImageStoragePlacement,
                           ImageStorageLocation, ImageStorageTransformation, ImageStorageSignature,
                           ImageStorageSignatureKind, Repository, Namespace)


logger = logging.getLogger(__name__)


def add_storage_placement(storage, location_name):
  """ Creates a placement row linking the given storage to the location with the given name.

      The location is looked up by name before the placement is created. """
  placement_location = ImageStorageLocation.get(name=location_name)
  ImageStoragePlacement.create(storage=storage, location=placement_location)


def find_or_create_derived_storage(source, transformation_name, preferred_location):
  """ Returns the storage derived from `source` by the named transformation.

      When no such derived storage exists yet, a new v1 storage is created at
      `preferred_location` and linked to `source` via a DerivedImageStorage row. """
  derived = find_derived_storage(source, transformation_name)
  if derived is None:
    logger.debug('Creating storage dervied from source: %s', source.uuid)
    transformation = ImageStorageTransformation.get(name=transformation_name)
    derived = create_v1_storage(preferred_location)
    DerivedImageStorage.create(source=source, derivative=derived, transformation=transformation)

  return derived


def garbage_collect_storage(storage_id_whitelist):
  """ Deletes unreferenced storages among the given candidates, their placements, and blob data.

      A candidate storage is removed when no Image row references it and it is not the derivative
      side of any DerivedImageStorage row. DerivedImageStorage rows whose source is such an
      orphaned storage are deleted first, and the ids of their derivative storages are added to
      `storage_id_whitelist` (which must therefore be a set; it is updated in place) so they are
      considered for removal as well.

      Blob data is removed outside of any database transaction, after the rows are gone. A content
      addressed (CAS) blob path is only removed when no remaining ImageStorage row still has the
      same content checksum, since such paths are shared between storages.
  """
  if not storage_id_whitelist:
    return

  def placements_to_filtered_paths_set(placements):
    """ Returns the (location name, layer path) pairs to remove for the given placements.

        CAS layer paths are derived from the content checksum, so several storage rows can share
        one path; a CAS path is skipped while any ImageStorage row still has that checksum. Must
        be called after the orphaned ImageStorage rows have been deleted. """
    if not placements:
      return set()

    cas_checksums = {placement.storage.content_checksum for placement in placements
                     if placement.storage.cas_path}

    referenced_checksums = set()
    if cas_checksums:
      still_referenced = (ImageStorage
                          .select(ImageStorage.content_checksum)
                          .where(ImageStorage.content_checksum << list(cas_checksums)))
      referenced_checksums = {storage.content_checksum for storage in still_referenced}
      if referenced_checksums:
        logger.debug('Keeping CAS blobs still referenced by other storages: %s',
                     referenced_checksums)

    return {(placement.location.name, get_layer_path(placement.storage))
            for placement in placements
            if not placement.storage.cas_path or
            placement.storage.content_checksum not in referenced_checksums}

  def orphaned_storage_query(select_base_query, candidates, group_by):
    # Restricts `select_base_query` to storages among `candidates` that no Image row references
    # and that are not the derivative of any DerivedImageStorage row.
    return (select_base_query
            .switch(ImageStorage)
            .join(Image, JOIN_LEFT_OUTER)
            .switch(ImageStorage)
            .join(DerivedImageStorage, JOIN_LEFT_OUTER,
                  on=(ImageStorage.id == DerivedImageStorage.derivative))
            .where(ImageStorage.id << list(candidates))
            .group_by(*group_by)
            .having((fn.Count(Image.id) == 0) & (fn.Count(DerivedImageStorage.id) == 0)))

  # Note: We remove the derived image storage in its own transaction as a way to reduce the
  # time that the transaction holds on the database indices. This could result in a derived
  # image storage being deleted for an image storage which is later reused during this time,
  # but since these are caches anyway, it isn't terrible and worth the tradeoff (for now).
  logger.debug('Garbage collecting derived storage from candidates: %s', storage_id_whitelist)
  with db_transaction():
    # Find out which derived storages will be removed, and add them to the whitelist
    # The comma after ImageStorage.id is VERY important, it makes it a tuple, which is a sequence
    orphaned_from_candidates = list(orphaned_storage_query(ImageStorage.select(ImageStorage.id),
                                                           storage_id_whitelist,
                                                           (ImageStorage.id,)))

    if len(orphaned_from_candidates) > 0:
      derived_to_remove = (ImageStorage
                           .select(ImageStorage.id)
                           .join(DerivedImageStorage,
                                 on=(ImageStorage.id == DerivedImageStorage.derivative))
                           .where(DerivedImageStorage.source << orphaned_from_candidates))
      storage_id_whitelist.update({derived.id for derived in derived_to_remove})

      # Remove the derived image storages with sources of orphaned storages
      (DerivedImageStorage
       .delete()
       .where(DerivedImageStorage.source << orphaned_from_candidates)
       .execute())

  # Note: Both of these deletes must occur in the same transaction (unfortunately) because a
  # storage without any placement is invalid, and a placement cannot exist without a storage.
  # TODO(jake): We might want to allow for null storages on placements, which would allow us to
  # delete the storages, then delete the placements in a non-transaction.
  logger.debug('Garbage collecting storages from candidates: %s', storage_id_whitelist)
  with db_transaction():
    # Track all of the data that should be removed from blob storage
    placements_to_remove = list(orphaned_storage_query(ImageStoragePlacement
                                                       .select(ImageStoragePlacement,
                                                               ImageStorage,
                                                               ImageStorageLocation)
                                                       .join(ImageStorageLocation)
                                                       .switch(ImageStoragePlacement)
                                                       .join(ImageStorage),
                                                       storage_id_whitelist,
                                                       (ImageStorage, ImageStoragePlacement,
                                                        ImageStorageLocation)))

    # Remove the placements for orphaned storages
    if len(placements_to_remove) > 0:
      placement_ids_to_remove = [placement.id for placement in placements_to_remove]
      placements_removed = (ImageStoragePlacement
                            .delete()
                            .where(ImageStoragePlacement.id << placement_ids_to_remove)
                            .execute())
      logger.debug('Removed %s image storage placements', placements_removed)

    # Remove all orphaned storages
    # The comma after ImageStorage.id is VERY important, it makes it a tuple, which is a sequence
    orphaned_storages = list(orphaned_storage_query(ImageStorage.select(ImageStorage.id),
                                                    storage_id_whitelist,
                                                    (ImageStorage.id,)).alias('osq'))
    if len(orphaned_storages) > 0:
      storages_removed = (ImageStorage
                          .delete()
                          .where(ImageStorage.id << orphaned_storages)
                          .execute())
      logger.debug('Removed %s image storage records', storages_removed)

    # Compute the blob paths only now that the orphaned rows are deleted (and still inside the
    # transaction), so CAS paths shared with surviving storages are excluded.
    paths_to_remove = placements_to_filtered_paths_set(placements_to_remove)

  # We are going to make the conscious decision to not delete image storage blobs inside
  # transactions.
  # This may end up producing garbage in s3, trading off for higher availability in the database.
  for location_name, image_path in paths_to_remove:
    logger.debug('Removing %s from %s', image_path, location_name)
    config.store.remove({location_name}, image_path)


def create_v1_storage(location_name):
  """ Creates a new storage row with `cas_path=False` and a single placement at the named location.

      The returned storage has its `locations` attribute set to a one-element set containing
      `location_name`. """
  new_storage = ImageStorage.create(cas_path=False)
  placement_location = ImageStorageLocation.get(name=location_name)
  ImageStoragePlacement.create(location=placement_location, storage=new_storage)
  new_storage.locations = set([location_name])
  return new_storage


def find_or_create_storage_signature(storage, signature_kind):
  """ Returns the signature row of the given kind for the storage.

      If none exists, a new row is created with only its storage and kind set. """
  existing = lookup_storage_signature(storage, signature_kind)
  if existing is not None:
    return existing

  kind_row = ImageStorageSignatureKind.get(name=signature_kind)
  return ImageStorageSignature.create(storage=storage, kind=kind_row)


def lookup_storage_signature(storage, signature_kind):
  """ Returns the signature of the given kind for the storage, or None if there is none.

      The kind itself is looked up by name first; an unknown kind name raises that model's
      DoesNotExist error rather than returning None. """
  kind_row = ImageStorageSignatureKind.get(name=signature_kind)
  matching = (ImageStorageSignature
              .select()
              .where(ImageStorageSignature.storage == storage,
                     ImageStorageSignature.kind == kind_row))
  try:
    return matching.get()
  except ImageStorageSignature.DoesNotExist:
    return None


def find_derived_storage(source, transformation_name):
  """ Returns the storage derived from `source` by the named transformation, or None.

      On success the returned storage has its `locations` attribute set to the names of all
      locations where it has a placement. """
  derived_query = (ImageStorage
                   .select(ImageStorage, DerivedImageStorage)
                   .join(DerivedImageStorage,
                         on=(ImageStorage.id == DerivedImageStorage.derivative))
                   .join(ImageStorageTransformation)
                   .where(DerivedImageStorage.source == source,
                          ImageStorageTransformation.name == transformation_name))
  try:
    derived = derived_query.get()
  except ImageStorage.DoesNotExist:
    return None

  derived.locations = set(placement.location.name
                          for placement in derived.imagestorageplacement_set)
  return derived


def delete_derived_storage_by_uuid(storage_uuid):
  """ Deletes the storage with the given UUID, recursively, but only if it is a derived storage.

      Does nothing when no storage with that UUID is found, or when the storage is not the
      derivative of any DerivedImageStorage row. """
  try:
    storage_row = get_storage_by_uuid(storage_uuid)
  except InvalidImageException:
    # Unknown UUID: nothing to delete.
    return

  try:
    DerivedImageStorage.get(derivative=storage_row)
  except DerivedImageStorage.DoesNotExist:
    # Only derived storages may be removed through this function.
    return
  else:
    storage_row.delete_instance(recursive=True)


def _get_storage(query_modifier):
  """ Loads a storage via its placements.

      Builds a placement query joined with location and storage, passes it through
      `query_modifier` (which typically adds a filter), and returns the storage of the first
      resulting placement. Its `locations` attribute is set to the names of all locations found.

      Raises InvalidImageException if the modified query yields no placements. """
  base_query = (ImageStoragePlacement
                .select(ImageStoragePlacement, ImageStorage, ImageStorageLocation)
                .join(ImageStorageLocation)
                .switch(ImageStoragePlacement)
                .join(ImageStorage))

  placements = list(query_modifier(base_query))
  if len(placements) == 0:
    raise InvalidImageException()

  storage = placements[0].storage
  storage.locations = set(placement.location.name for placement in placements)
  return storage


def get_storage_by_uuid(storage_uuid):
  """ Returns the storage with the given UUID, with its `locations` attribute populated.

      Raises InvalidImageException if no storage with that UUID has any placement. """
  def filter_to_uuid(query):
    return query.where(ImageStorage.uuid == storage_uuid)

  try:
    return _get_storage(filter_to_uuid)
  except InvalidImageException:
    # Interpolate the UUID into the message. Passing it as a separate argument (logging-style)
    # leaves the '%s' unformatted and turns the exception args into a tuple.
    raise InvalidImageException('No storage found with uuid: %s' % (storage_uuid,))


def get_layer_path(storage_record):
  """ Returns the path in the storage engine to the layer data referenced by the storage row.

      Storages with `cas_path` set use the blob path derived from their content checksum. All
      others use the legacy v1 layer path derived from the storage UUID. """
  store = config.store
  if storage_record.cas_path:
    return store.blob_path(storage_record.content_checksum)

  logger.debug('Serving layer from legacy v1 path')
  return store.v1_image_layer_path(storage_record.uuid)


def lookup_repo_storages_by_content_checksum(repo, checksums):
  """ Looks up repository storages (without placements) matching the given repository
      and checksums.

      Returns a query yielding at most one row (id, content_checksum, image_size) per distinct
      checksum that has a storage linked to an Image in `repo`. Returns an empty list when
      `checksums` contains no values, since there is nothing to union. """
  # Imported locally: `reduce` is only a builtin on Python 2.
  from functools import reduce

  unique_checksums = set(checksums)
  if not unique_checksums:
    # reduce() over an empty sequence without an initial value raises TypeError.
    return []

  # There may be many duplicates of the checksums, so for performance reasons we are going
  # to use a union to select just one storage with each checksum
  queries = []

  for counter, checksum in enumerate(unique_checksums):
    query_alias = 'q{0}'.format(counter)
    candidate_subq = (ImageStorage
                      .select(ImageStorage.id, ImageStorage.content_checksum, ImageStorage.image_size)
                      .join(Image)
                      .where(Image.repository == repo, ImageStorage.content_checksum == checksum)
                      .limit(1)
                      .alias(query_alias))
    queries.append(ImageStorage
                   .select(SQL('*'))
                   .from_(candidate_subq))

  return reduce(lambda l, r: l.union_all(r), queries)


def get_storage_locations(uuid):
  """ Returns the names of the locations of every placement of the storage with the given UUID.
  """
  # Select the location alongside the placement so `placement.location` is populated from this
  # query, rather than lazily fetched with one extra query per placement.
  query = (ImageStoragePlacement
           .select(ImageStoragePlacement, ImageStorageLocation)
           .join(ImageStorageLocation)
           .switch(ImageStoragePlacement)
           .join(ImageStorage, JOIN_LEFT_OUTER)
           .where(ImageStorage.uuid == uuid))

  return [placement.location.name for placement in query]