class ImageTreeNode(object):
  """ A single image inside an image tree, along with the tag names attached to it. """
  def __init__(self, image, child_map):
    # Lookup supplied by the caller; `children` reads it using the stringified image ID.
    self._child_map = child_map

    self.image = image
    self.tags = []
    self.parent = None

  @property
  def children(self):
    """ The entry stored in the child map under this image's ID, or an empty list if none. """
    lookup_key = str(self.image.id)
    return self._child_map.get(lookup_key, [])

  def add_tag(self, tag):
    """ Appends the given tag to this node's list of tags. """
    self.tags.append(tag)


class ImageTree(object):
  """ In-memory tree for easy traversal and lookup of images in a repository.

      Nodes are indexed by image ID in `_image_map`. `_child_map` maps a parent image ID (the
      string taken from an image's `ancestors` path) to the list of that parent's child nodes.
  """

  def __init__(self, all_images, all_tags, base_filter=None):
    """ Builds the tree from the given images and tags.

        If base_filter is not None, only the image whose ID equals it and the images that list
        it among their ancestors are kept in the tree.
    """
    self._image_map = {}
    self._child_map = {}

    self._build(all_images, all_tags, base_filter)

  def _build(self, all_images, all_tags, base_filter=None):
    """ Fills the image and child maps from all_images, then attaches the name of each tag
        in all_tags to the node of the image it points at (tags on unknown images are skipped).
    """
    # Build nodes for each of the images.
    for image in all_images:
      # Assumes `ancestors` is a path such as '/<id>/<id>/': the first and last split pieces
      # are dropped and the piece before the last one is the direct parent. The path is split
      # once, and a missing (None/empty) value is treated as "no ancestors" rather than
      # raising.
      path_parts = image.ancestors.split('/') if image.ancestors else []
      ancestors = path_parts[1:-1]

      # Filter any unneeded images.
      if base_filter is not None:
        if image.id != base_filter and str(base_filter) not in ancestors:
          continue

      # Create the node for the image.
      image_node = ImageTreeNode(image, self._child_map)
      self._image_map[image.id] = image_node

      # Add the node to the child map for its parent image (if any). The length check keeps
      # a path without any '/' from raising an IndexError.
      parent_image_id = path_parts[-2] if len(path_parts) >= 2 else None
      if parent_image_id:
        self._child_map.setdefault(parent_image_id, []).append(image_node)

    # Build the tag map.
    for tag in all_tags:
      image_node = self._image_map.get(tag.image.id)
      if not image_node:
        continue

      image_node.add_tag(tag.name)


  def find_longest_path(self, image_id, checker):
    """ Returns a list of images representing the longest path that matches the given
        checker function, starting from the given image_id *exclusive*.

        The checker is called as checker(index, image), where index is 0 for the direct
        children of the starting image and grows by one per level. Returns an empty list
        if image_id is not in the tree.
    """
    start_node = self._image_map.get(image_id)
    if not start_node:
      return []

    # The helper's result starts with the starting image itself; drop it.
    return self._find_longest_path(start_node, checker, -1)[1:]


  def _find_longest_path(self, image_node, checker, index):
    """ Returns the image of image_node followed by the longest chain of descendants that
        all pass the checker. Children rejected by the checker are not descended into.
    """
    found_path = []

    for child_node in image_node.children:
      if not checker(index + 1, child_node.image):
        continue

      found = self._find_longest_path(child_node, checker, index + 1)
      # Strict comparison: on a tie the path found first is kept.
      if len(found) > len(found_path):
        found_path = found

    return [image_node.image] + found_path


  def tag_containing_image(self, image):
    """ Returns the name of the closest tag containing the given image.

        The image's own node is checked first, then its children, then its grandchildren
        and so on, so a tag fewer levels away always wins over a deeper one. The first tag
        recorded on the matching node is returned. Returns None if the image is falsy, is
        not in the tree, or neither it nor any image deriving from it is tagged.
    """
    if not image:
      return None

    start_node = self._image_map.get(image.id)
    if start_node is None:
      return None

    # Walk the tree one level at a time; a depth-first walk could return a deeply nested
    # tag from an early branch even though a later branch is tagged closer to the image.
    current_level = [start_node]
    while current_level:
      for node in current_level:
        if node.tags:
          return node.tags[0]

      current_level = [child for node in current_level for child in node.children]

    return None