class ImageTreeNode(object):
  """ One image in the image tree, linked to its parent node and its child nodes. """

  def __init__(self, image):
    # A fresh node is detached: no parent, no children and no tags until the
    # add_child/add_tag methods are called.
    self.image, self.parent = image, None
    self.children, self.tags = [], []

  def add_child(self, child):
    """ Attaches the given node beneath this one and points it back at this node. """
    self.children += [child]
    child.parent = self

  def add_tag(self, tag):
    """ Appends the given tag to this node's list of tags. """
    self.tags += [tag]


class ImageTree(object):
  """ In-memory tree for easy traversal and lookup of images in a repository.

      Nodes are indexed by image id and by tag name. Every image must expose an `id` and
      an `ancestors` string of '/'-delimited ancestor ids (the parsing expects a form such
      as '/1/2/'), whose last entry is the id of the image's direct parent.
  """

  def __init__(self, all_images, all_tags, base_filter=None):
    """ Builds the tree from the given images and tags.

        all_images: iterable of images exposing `id` and `ancestors`.
        all_tags: iterable of tags exposing `name` and `image` (with an `id`).
        base_filter: optional image id. When given, only that image and the images that
                     list it among their ancestors are placed in the tree.
    """
    self._tag_map = {}
    self._image_map = {}

    self._build(all_images, all_tags, base_filter)

  def _build(self, all_images, all_tags, base_filter=None):
    """ Populates the image map, links nodes to their parents and fills the tag map. """
    # Build nodes for each of the images.
    for image in all_images:
      # Filter any unneeded images: keep the base image itself and its descendants.
      if base_filter is not None and image.id != base_filter:
        ancestors = (image.ancestors or '').split('/')[1:-1]
        if str(base_filter) not in ancestors:
          continue

      self._image_map[image.id] = ImageTreeNode(image)

    # Connect the nodes to their parents.
    for image_node in self._image_map.values():
      # The direct parent is the second-to-last piece of the split string (the last
      # piece is what follows the trailing '/'). Guard against strings with no '/'.
      pieces = (image_node.image.ancestors or '').split('/')
      parent_image_id = pieces[-2] if len(pieces) >= 2 else None
      if not parent_image_id:
        continue

      # The parent may be absent from the map (e.g. filtered out); such nodes stay
      # unattached.
      parent_node = self._image_map.get(int(parent_image_id))
      if parent_node is not None:
        parent_node.add_child(image_node)

    # Build the tag map.
    for tag in all_tags:
      image_node = self._image_map.get(tag.image.id)
      if image_node is None:
        continue

      # Record the node under the tag's name; the map itself must remain a dict.
      self._tag_map[tag.name] = image_node
      image_node.add_tag(tag.name)


  def find_longest_path(self, image_id, checker):
    """ Returns a list of images representing the longest path that matches the given
        checker function, starting from the given image_id *exclusive*.

        The checker is called as checker(index, image), where index is the 0-based position
        the image would take in the returned path. An unknown image_id yields [].
    """
    start_node = self._image_map.get(image_id)
    if start_node is None:
      return []

    # The helper's result starts with the start image itself; drop it.
    return self._find_longest_path(start_node, checker, -1)[1:]


  def _find_longest_path(self, image_node, checker, index):
    """ Returns the image of the given node followed by the longest accepted path below it.

        `index` is the path position of the given node (-1 for the excluded start node).
        When several children give paths of equal length, the first one found is kept.
    """
    found_path = []

    for child_node in image_node.children:
      if not checker(index + 1, child_node.image):
        continue

      found = self._find_longest_path(child_node, checker, index + 1)
      if len(found) > len(found_path):
        found_path = found

    return [image_node.image] + found_path


  def tag_containing_image(self, image):
    """ Returns the name of a tag containing the given image, or None if there is none.

        The image's own first tag wins. Otherwise the deriving images are searched
        depth-first in child order and the first tag found is returned, so the result is
        not necessarily the tag with the fewest hops from the image.
    """
    if not image:
      return None

    # Check the current image for a tag.
    image_node = self._image_map.get(image.id)
    if image_node is None:
      return None

    if image_node.tags:
      return image_node.tags[0]

    # Check any deriving images for a tag.
    for child_node in image_node.children:
      found = self.tag_containing_image(child_node.image)
      if found is not None:
        return found

    return None