import logging
import dateutil.parser
import random

from peewee import JOIN_LEFT_OUTER, fn, SQL
from datetime import datetime

from data.model import (DataModelException, db_transaction, _basequery, storage,
                        InvalidImageException, config)
from data.database import (Image, Repository, ImageStoragePlacement, Namespace, ImageStorage,
                           ImageStorageLocation, RepositoryPermission, DerivedStorageForImage,
                           ImageStorageTransformation, db_random_func, db_for_update)


logger = logging.getLogger(__name__)


def get_repository_image_and_deriving(docker_image_id, storage_uuid):
  """ Returns all matching images with the given docker image ID and storage uuid, along with any
      images which have the image ID as parents.
  """
  try:
    matched = (Image
               .select()
               .join(ImageStorage)
               .where(Image.docker_image_id == docker_image_id,
                      ImageStorage.uuid == storage_uuid)
               .get())
  except Image.DoesNotExist:
    # No such image: return a query guaranteed to match nothing.
    return Image.select().where(Image.id < 0)

  # Any deriving image carries the matched image's id in its ancestry path.
  descendant_pattern = '%s%s/%%' % (matched.ancestors, matched.id)
  return (Image
          .select()
          .where((Image.ancestors ** descendant_pattern) | (Image.id == matched.id)))


def get_parent_images(namespace_name, repository_name, image_obj):
  """ Returns a list of parent Image objects starting with the most recent parent
      and ending with the base layer.
  """
  # Ancestors are formatted as /<root>/<intermediate>/.../<parent>/, where each path
  # component is the database id of an image row.
  ancestor_ids = image_obj.ancestors.strip('/').split('/')
  if ancestor_ids == ['']:
    # No ancestors: this is a base image.
    return []

  parent_images = get_repository_images_base(namespace_name, repository_name,
                                             lambda q: q.where(Image.id << ancestor_ids))

  lookup = {unicode(parent.id): parent for parent in parent_images}

  # Reverse so the most recent parent comes first and the base layer last.
  return [lookup[ancestor_id] for ancestor_id in reversed(ancestor_ids)]


def get_repo_image(namespace_name, repository_name, docker_image_id):
  """ Returns the Image with the given docker image id in the named repository, or None
      if no such image exists.
  """
  def restrict_to_image(query):
    return query.where(Image.docker_image_id == docker_image_id).limit(1)

  try:
    return _get_repository_images(namespace_name, repository_name, restrict_to_image).get()
  except Image.DoesNotExist:
    return None


def get_repo_image_extended(namespace_name, repository_name, docker_image_id):
  """ Returns the Image (with its storage and placement locations loaded) with the given docker
      image id in the named repository, or None if no such image exists.
  """
  def restrict_to_image(query):
    return query.where(Image.docker_image_id == docker_image_id)

  found = get_repository_images_base(namespace_name, repository_name, restrict_to_image)
  return found[0] if found else None


def _get_repository_images(namespace_name, repository_name, query_modifier):
  """ Returns the result of applying query_modifier to a query over all images in the
      named repository.
  """
  base_query = (Image
                .select()
                .join(Repository)
                .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                .where(Repository.name == repository_name,
                       Namespace.username == namespace_name))

  return query_modifier(base_query)


def get_repository_images_base(namespace_name, repository_name, query_modifier):
  """ Returns the images of the named repository, each with its storage and the set of
      placement locations attached, after applying query_modifier to the underlying
      placement query.
  """
  placements = (ImageStoragePlacement
                .select(ImageStoragePlacement, Image, ImageStorage, ImageStorageLocation)
                .join(ImageStorageLocation)
                .switch(ImageStoragePlacement)
                .join(ImageStorage, JOIN_LEFT_OUTER)
                .join(Image)
                .join(Repository)
                .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                .where(Repository.name == repository_name,
                       Namespace.username == namespace_name))

  return invert_placement_query_results(query_modifier(placements))


def invert_placement_query_results(placement_query):
  """ This method will take a query which returns placements, storages, and images, and have it
      return images and their storages, along with the placement set on each storage.

      Returns the distinct images, with each image.storage.locations populated as the set of
      location names found across the matching placements.
  """
  images = {}
  for placement in placement_query:
    # Make sure we're always retrieving the same image object for a given image id.
    image = placement.storage.image

    # Set the storage to the one we got from the placement, to prevent another query.
    image.storage = placement.storage

    if image.id not in images:
      images[image.id] = image
      image.storage.locations = set()
    else:
      # A row for this image was already seen; accumulate onto the canonical object.
      image = images[image.id]

    # Add the placement's location to the image's locations set.
    image.storage.locations.add(placement.location.name)

  return images.values()


def lookup_repository_images(repo, docker_image_ids):
  """ Returns a query for all images in the given repository whose docker image id appears in
      the given collection of ids.
  """
  return (Image
          .select()
          .where(Image.repository == repo,
                 Image.docker_image_id << docker_image_ids))


def get_matching_repository_images(namespace_name, repository_name, docker_image_ids):
  """ Returns the images (with storages and placements) in the named repository whose docker
      image ids appear in docker_image_ids.
  """
  def restrict_to_ids(query):
    return query.where(Image.docker_image_id << list(docker_image_ids))

  return get_repository_images_base(namespace_name, repository_name, restrict_to_ids)


def get_repository_images_without_placements(repo_obj, with_ancestor=None):
  """ Returns a query for the images (with storage joined, but no placements) in the given
      repository. If with_ancestor is given, only that image and its descendants are included.
  """
  images = (Image
            .select(Image, ImageStorage)
            .join(ImageStorage)
            .where(Image.repository == repo_obj))

  if with_ancestor:
    # Descendants carry the ancestor's id in their ancestry path.
    prefix = '%s%s/' % (with_ancestor.ancestors, with_ancestor.id)
    images = images.where((Image.ancestors ** (prefix + '%')) |
                          (Image.id == with_ancestor.id))

  return images


def get_repository_images(namespace_name, repository_name):
  """ Returns all images (with storages and placement locations) in the named repository. """
  def unmodified(query):
    return query

  return get_repository_images_base(namespace_name, repository_name, unmodified)


def get_image_by_id(namespace_name, repository_name, docker_image_id):
  """ Returns the image with the given docker image id in the named repository.

      Raises InvalidImageException if no such image exists.
  """
  found = get_repo_image_extended(namespace_name, repository_name, docker_image_id)
  if found is None:
    raise InvalidImageException('Unable to find image \'%s\' for repo \'%s/%s\'' %
                                (docker_image_id, namespace_name, repository_name))
  return found


def __translate_ancestry(old_ancestry, translations, repo_obj, username, preferred_location):
  """ Rewrites an ancestry string (/<root id>/.../<parent id>/) so each component refers to an
      image in repo_obj, finding, linking, or creating the corresponding images as needed.

      translations is a mutable {old image id -> new image id} cache shared with (and updated
      by) find_create_or_link_image, which this function calls recursively per ancestor.
      Returns the translated ancestry string ('/' for a base image).
  """
  if old_ancestry == '/':
    return '/'

  def translate_id(old_id, docker_image_id):
    # Find/link/create the ancestor in repo_obj on a cache miss, memoizing the result.
    logger.debug('Translating id: %s', old_id)
    if old_id not in translations:
      image_in_repo = find_create_or_link_image(docker_image_id, repo_obj, username, translations,
                                                preferred_location)
      translations[old_id] = image_in_repo.id
    return translations[old_id]

  # Select all the ancestor Docker IDs in a single query.
  old_ids = [int(id_str) for id_str in old_ancestry.split('/')[1:-1]]
  query = Image.select(Image.id, Image.docker_image_id).where(Image.id << old_ids)
  old_images = {i.id: i.docker_image_id for i in  query}

  # Translate the old images into new ones.
  new_ids = [str(translate_id(old_id, old_images[old_id])) for old_id in old_ids]
  return '/%s/' % '/'.join(new_ids)


def _find_or_link_image(existing_image, repo_obj, username, translations, preferred_location):
  """ Creates a copy of existing_image (an Image from another repository) in repo_obj, sharing
      its storage rather than duplicating it.

      Returns the image already in repo_obj if one exists for the same docker image id, the
      newly linked Image otherwise, or None if existing_image has since been deleted.
      Updates the shared translations cache ({old image id -> new image id}) as a side effect.
  """
  # TODO(jake): This call is currently recursively done under a single transaction. Can we make
  # it instead be done under a set of transactions?
  with db_transaction():
    # Check for an existing image, under the transaction, to make sure it doesn't already exist.
    repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
                                existing_image.docker_image_id)
    if repo_image:
      return repo_image

    # Make sure the existing base image still exists.
    try:
      to_copy = Image.select().join(ImageStorage).where(Image.id == existing_image.id).get()

      msg = 'Linking image to existing storage with docker id: %s and uuid: %s'
      logger.debug(msg, existing_image.docker_image_id, to_copy.storage.uuid)

      # Rewrite the ancestry so each ancestor id refers to an image in repo_obj; this may
      # recursively link/create the ancestors as well.
      new_image_ancestry = __translate_ancestry(to_copy.ancestors, translations, repo_obj,
                                                username, preferred_location)

      # Share the source image's storage, carrying over its placement location names.
      copied_storage = to_copy.storage
      copied_storage.locations = {placement.location.name
                                  for placement in copied_storage.imagestorageplacement_set}

      # The direct parent is the last component of the translated ancestry (if any).
      translated_parent_id = None
      if new_image_ancestry != '/':
        translated_parent_id = int(new_image_ancestry.split('/')[-2])

      new_image = Image.create(docker_image_id=existing_image.docker_image_id,
                               repository=repo_obj,
                               storage=copied_storage,
                               ancestors=new_image_ancestry,
                               command=existing_image.command,
                               created=existing_image.created,
                               comment=existing_image.comment,
                               v1_json_metadata=existing_image.v1_json_metadata,
                               aggregate_size=existing_image.aggregate_size,
                               parent=translated_parent_id,
                               v1_checksum=existing_image.v1_checksum)


      logger.debug('Storing translation %s -> %s', existing_image.id, new_image.id)
      translations[existing_image.id] = new_image.id
      return new_image
    except Image.DoesNotExist:
      # The source image was deleted out from under us; nothing to link.
      return None


def find_create_or_link_image(docker_image_id, repo_obj, username, translations,
                              preferred_location):
  """ Returns an Image with the given docker image id in repo_obj, producing it by (in order):
      finding it already in the repository, linking to an existing storage from another
      repository visible to username, or creating a brand-new storage in preferred_location.

      translations is a mutable {old image id -> new image id} cache shared with the linking
      helpers and updated as a side effect.
  """

  # First check for the image existing in the repository. If found, we simply return it.
  repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
                              docker_image_id)
  if repo_image:
    return repo_image

  # We next check to see if there is an existing storage the new image can link to.
  existing_image_query = (Image
                          .select(Image, ImageStorage)
                          .distinct()
                          .join(ImageStorage)
                          .switch(Image)
                          .join(Repository)
                          .join(RepositoryPermission, JOIN_LEFT_OUTER)
                          .switch(Repository)
                          .join(Namespace, on=(Repository.namespace_user == Namespace.id))
                          .where(ImageStorage.uploading == False,
                                 Image.docker_image_id == docker_image_id))

  # Only storages in repositories the user can see may be linked to.
  existing_image_query = _basequery.filter_to_repos_for_user(existing_image_query, username)

  # If there is an existing image, we try to translate its ancestry and copy its storage.
  new_image = None
  try:
    logger.debug('Looking up existing image for ID: %s', docker_image_id)
    existing_image = existing_image_query.get()

    logger.debug('Existing image %s found for ID: %s', existing_image.id, docker_image_id)
    new_image = _find_or_link_image(existing_image, repo_obj, username, translations,
                                    preferred_location)
    if new_image:
      return new_image
  except Image.DoesNotExist:
    logger.debug('No existing image found for ID: %s', docker_image_id)

  # Otherwise, create a new storage directly.
  with db_transaction():
    # Final check for an existing image, under the transaction.
    repo_image = get_repo_image(repo_obj.namespace_user.username, repo_obj.name,
                                docker_image_id)
    if repo_image:
      return repo_image

    logger.debug('Creating new storage for docker id: %s', docker_image_id)
    new_storage = storage.create_v1_storage(preferred_location)

    return Image.create(docker_image_id=docker_image_id,
                        repository=repo_obj, storage=new_storage,
                        ancestors='/')


def set_image_metadata(docker_image_id, namespace_name, repository_name, created_date_str, comment,
                       command, v1_json_metadata, parent=None):
  """ Sets the metadata (creation date, comment, command, V1 JSON and optionally parent) on the
      image with the given docker image id in the named repository, clearing any previously
      stored checksums.

      Returns the updated Image. Raises DataModelException if no such image exists.
  """
  with db_transaction():
    query = (Image
             .select(Image, ImageStorage)
             .join(Repository)
             .join(Namespace, on=(Repository.namespace_user == Namespace.id))
             .switch(Image)
             .join(ImageStorage)
             .where(Repository.name == repository_name, Namespace.username == namespace_name,
                    Image.docker_image_id == docker_image_id))

    try:
      # Lock the row for update so concurrent writers cannot interleave with us.
      fetched = db_for_update(query).get()
    except Image.DoesNotExist:
      raise DataModelException('No image with specified id and repository')

    # Default the creation time to now; overridden below when a parsable date was supplied.
    fetched.created = datetime.now()
    if created_date_str is not None:
      try:
        fetched.created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
      except Exception:
        # parse raises different exceptions, so we cannot use a specific kind of handler here.
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
        pass

    # We cleanup any old checksum in case it's a retry after a fail
    fetched.v1_checksum = None
    fetched.storage.content_checksum = None

    fetched.comment = comment
    fetched.command = command
    fetched.v1_json_metadata = v1_json_metadata

    if parent:
      fetched.ancestors = '%s%s/' % (parent.ancestors, parent.id)
      fetched.parent = parent

    fetched.save()
    return fetched


def set_image_size(docker_image_id, namespace_name, repository_name, image_size, uncompressed_size):
  """ Stores the compressed and uncompressed sizes on the storage of the image with the given
      docker image id in the named repository, then recomputes the image's aggregate size.

      Returns the updated Image. Raises DataModelException if image_size is missing or no
      matching image exists.
  """
  if image_size is None:
    raise DataModelException('Empty image size field')

  query = (Image
           .select(Image, ImageStorage)
           .join(Repository)
           .join(Namespace, on=(Repository.namespace_user == Namespace.id))
           .switch(Image)
           .join(ImageStorage, JOIN_LEFT_OUTER)
           .where(Repository.name == repository_name, Namespace.username == namespace_name,
                  Image.docker_image_id == docker_image_id))

  try:
    found = query.get()
  except Image.DoesNotExist:
    raise DataModelException('No image with specified id and repository')

  found.storage.image_size = image_size
  found.storage.uncompressed_size = uncompressed_size
  found.storage.save()

  found.aggregate_size = calculate_image_aggregate_size(found.ancestors, found.storage,
                                                        found.parent)
  found.save()

  return found


def calculate_image_aggregate_size(ancestors_str, image_storage, parent_image):
  """ Returns the aggregate size of an image: its own storage size plus that of all its
      ancestors, or None if the ancestor sizes cannot be determined.

      Raises DataModelException if the image has ancestors but parent_image is None.
  """
  ancestor_ids = ancestors_str.split('/')[1:-1]
  if not ancestor_ids:
    # Base image: its aggregate size is just its own storage size.
    return image_storage.image_size

  if parent_image is None:
    raise DataModelException('Could not load parent image')

  parent_aggregate = parent_image.aggregate_size
  if parent_aggregate is not None:
    return parent_aggregate + image_storage.image_size

  # Fallback to a slower path if the parent doesn't have an aggregate size saved.
  # TODO: remove this code if/when we do a full backfill.
  summed_ancestors = (ImageStorage
                      .select(fn.Sum(ImageStorage.image_size))
                      .join(Image)
                      .where(Image.id << ancestor_ids)
                      .scalar())
  if summed_ancestors is None:
    return None

  return summed_ancestors + image_storage.image_size


def get_image(repo, docker_image_id):
  """ Returns the image with the given docker image id in the given repository, or None if
      no such image exists.
  """
  query = Image.select().where(Image.docker_image_id == docker_image_id,
                               Image.repository == repo)
  try:
    return query.get()
  except Image.DoesNotExist:
    return None


def get_repo_image_by_storage_checksum(namespace, repository_name, storage_checksum):
  """ Returns the image in the named repository whose (fully uploaded) storage has the given
      content checksum.

      Raises InvalidImageException if no such image exists.
  """
  query = (Image
           .select()
           .join(ImageStorage)
           .switch(Image)
           .join(Repository)
           .join(Namespace, on=(Namespace.id == Repository.namespace_user))
           .where(Repository.name == repository_name, Namespace.username == namespace,
                  ImageStorage.content_checksum == storage_checksum,
                  ImageStorage.uploading == False))

  try:
    return query.get()
  except Image.DoesNotExist:
    msg = 'Image with storage checksum {0} does not exist in repo {1}/{2}'.format(storage_checksum,
                                                                                  namespace,
                                                                                  repository_name)
    raise InvalidImageException(msg)


def get_image_layers(image):
  """ Returns a list of the full layers of an image, including itself (if specified), sorted
      from base image outward. """
  # Ancestor ids (base first), followed by the image's own id.
  layer_ids = [ancestor_id for ancestor_id in image.ancestors.split('/')[1:-1] if ancestor_id]
  layer_ids.append(str(image.id))

  placements = (ImageStoragePlacement
                .select(ImageStoragePlacement, Image, ImageStorage, ImageStorageLocation)
                .join(ImageStorageLocation)
                .switch(ImageStoragePlacement)
                .join(ImageStorage, JOIN_LEFT_OUTER)
                .join(Image)
                .where(Image.id << layer_ids))

  layers = list(invert_placement_query_results(placements))
  # Order the results to match the ancestry path ordering.
  layers.sort(key=lambda layer: layer_ids.index(str(layer.id)))
  return layers


def synthesize_v1_image(repo, image_storage, docker_image_id, created_date_str,
                        comment, command, v1_json_metadata, parent_image=None):
  """ Writes a new Image row with the given docker image id and metadata, attached to the
      given storage and (optional) parent image.

      NOTE(review): the original docstring claimed an existing image would be found first, but
      this function unconditionally creates; callers are expected to check for existence.
  """
  ancestors = '/'
  if parent_image is not None:
    ancestors = '{0}{1}/'.format(parent_image.ancestors, parent_image.id)

  created = None
  if created_date_str is not None:
    try:
      created = dateutil.parser.parse(created_date_str).replace(tzinfo=None)
    except Exception:
      # parse raises different exceptions, so we cannot use a specific kind of handler here.
      # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are not swallowed.
      pass

  # Get the aggregate size for the image.
  aggregate_size = calculate_image_aggregate_size(ancestors, image_storage, parent_image)

  return Image.create(docker_image_id=docker_image_id, ancestors=ancestors, comment=comment,
                      command=command, v1_json_metadata=v1_json_metadata, created=created,
                      storage=image_storage, repository=repo, parent=parent_image,
                      aggregate_size=aggregate_size)


def ensure_image_locations(*names):
  """ Ensures an ImageStorageLocation row exists for each of the given location names,
      inserting any that are missing.
  """
  with db_transaction():
    existing = ImageStorageLocation.select().where(ImageStorageLocation.name << names)

    # Use a set difference rather than repeated list.remove calls: O(n) instead of O(n^2),
    # and duplicate input names no longer cause duplicate inserts.
    missing = set(names) - {location.name for location in existing}
    if not missing:
      return

    ImageStorageLocation.insert_many([{'name': name} for name in missing]).execute()

def get_secscan_candidates(engine_version, batch_size):
  """ Returns up to 2 * batch_size images that still need security scanning at the given engine
      version: images without parents, and images whose parent has already been analyzed at (or
      past) engine_version. Only images whose storage upload has completed are considered.
      The combined result is shuffled.
  """
  Parent = Image.alias()
  ParentImageStorage = ImageStorage.alias()
  rimages = []

  # Collect the images without parents
  candidates = list(Image
                    .select(Image.id)
                    .join(ImageStorage)
                    .where(Image.security_indexed_engine < engine_version,
                           Image.parent >> None,
                           ImageStorage.uploading == False)
                    .limit(batch_size*10))

  if len(candidates) > 0:
    # Sample batch_size of the candidates at random, loading their storages.
    images = (Image
              .select(Image, ImageStorage)
              .join(ImageStorage)
              .where(Image.id << candidates)
              .order_by(db_random_func())
              .limit(batch_size))
    rimages.extend(images)

  # Collect the images with analyzed parents.
  candidates = list(Image
                    .select(Image.id)
                    .join(Parent, on=(Image.parent == Parent.id))
                    .switch(Image)
                    .join(ImageStorage)
                    .where(Image.security_indexed_engine < engine_version,
                           Parent.security_indexed_engine >= engine_version,
                           ImageStorage.uploading == False)
                    .limit(batch_size*10))

  if len(candidates) > 0:
    # Sample batch_size of the candidates at random, loading their storages along with the
    # parent image and its storage.
    images = (Image
              .select(Image, ImageStorage, Parent, ParentImageStorage)
              .join(Parent, on=(Image.parent == Parent.id))
              .join(ParentImageStorage, on=(ParentImageStorage.id == Parent.storage))
              .switch(Image)
              .join(ImageStorage)
              .where(Image.id << candidates)
              .order_by(db_random_func())
              .limit(batch_size))
    rimages.extend(images)

  # Shuffle the images, otherwise the images without parents will always be on the top
  random.shuffle(rimages)

  return rimages


def set_secscan_status(image, indexed, version):
  """ Marks every image sharing the given image's docker image id and storage uuid with the
      given security-indexed status and engine version.
  """
  matching = (Image
              .select()
              .join(ImageStorage)
              .where(Image.docker_image_id == image.docker_image_id,
                     ImageStorage.uuid == image.storage.uuid))

  image_ids = [matched.id for matched in matching]
  if not image_ids:
    return

  (Image
   .update(security_indexed=indexed, security_indexed_engine=version)
   .where(Image.id << image_ids)
   .execute())


def find_or_create_derived_storage(source_image, transformation_name, preferred_location):
  """ Returns the derived storage for the given source image and transformation, creating a
      new storage in the preferred location (and linking it via DerivedStorageForImage)
      if none exists yet.
  """
  existing = find_derived_storage_for_image(source_image, transformation_name)
  if existing is not None:
    return existing

  # Fixed typo in the log message ('dervied' -> 'derived').
  logger.debug('Creating storage derived from source image: %s', source_image.id)
  trans = ImageStorageTransformation.get(name=transformation_name)
  new_storage = storage.create_v1_storage(preferred_location)
  DerivedStorageForImage.create(source_image=source_image, derivative=new_storage,
                                transformation=trans)
  return new_storage


def find_derived_storage_for_image(source_image, transformation_name):
  """ Returns the derived ImageStorage for the given source image and transformation name,
      with its locations attribute populated, or None if no such storage exists.
  """
  query = (ImageStorage
           .select(ImageStorage, DerivedStorageForImage)
           .join(DerivedStorageForImage)
           .join(ImageStorageTransformation)
           .where(DerivedStorageForImage.source_image == source_image,
                  ImageStorageTransformation.name == transformation_name))

  try:
    derived = query.get()
  except ImageStorage.DoesNotExist:
    return None

  # Attach the set of placement location names for the storage.
  derived.locations = {placement.location.name for placement in derived.imagestorageplacement_set}
  return derived


def delete_derived_storage_by_uuid(storage_uuid):
  """ Deletes the storage with the given uuid (recursively, along with dependent rows), but
      only if it exists and is a derived storage; otherwise this is a no-op.
  """
  try:
    found_storage = storage.get_storage_by_uuid(storage_uuid)
  except InvalidImageException:
    # No storage with this uuid; nothing to delete.
    return

  try:
    DerivedStorageForImage.get(derivative=found_storage)
  except DerivedStorageForImage.DoesNotExist:
    # Not a derived storage; refuse to delete.
    return

  found_storage.delete_instance(recursive=True)