import marisa_trie
import os
import tarfile
from aufs import is_aufs_metadata, get_deleted_prefix


# Prefix used by AUFS for its internal bookkeeping entries (not real files).
# NOTE(review): these constants appear unused in this chunk; the helpers
# imported from `aufs` presumably encapsulate the same prefixes — confirm
# against that module before removing.
AUFS_METADATA = u'.wh..wh.'

# Prefix marking an AUFS "whiteout" entry: the named path was deleted in this layer.
AUFS_WHITEOUT = u'.wh.'
AUFS_WHITEOUT_PREFIX_LENGTH = len(AUFS_WHITEOUT)

class StreamLayerMerger(object):
  """ Class which creates a generator of the combined TAR data for a set of Docker layers. """
  def __init__(self, layer_iterator):
    # Trie of every path already emitted (plus deleted-path prefixes); rebuilt
    # from `encountered` after each layer so membership checks stay fast.
    self.trie = marisa_trie.Trie()
    # Callable returning an iterable of file-like layer streams, topmost first.
    self.layer_iterator = layer_iterator
    # Paths (and whiteout prefixes) seen so far, accumulated across layers.
    # NOTE: never reset, so get_generator() is effectively single-use per instance.
    self.encountered = []

  def get_generator(self):
    """ Yields byte chunks forming a single squashed TAR stream for the layers
        produced by self.layer_iterator: each path is emitted only the first
        time it is seen, and AUFS whiteout/metadata entries are stripped.
    """
    # 9MB read size (+ padding below) so that it matches the 10MB expected by Gzip.
    # Hoisted out of the per-layer loop: it is invariant across layers.
    chunk_size = 1024 * 1024 * 9

    for current_layer in self.layer_iterator():
      # Read the current layer as TAR. If it is empty, we just continue
      # to the next layer.
      try:
        tar_file = tarfile.open(mode='r|*', fileobj=current_layer)
      except tarfile.ReadError:
        # Was `as re`, which shadowed the stdlib `re` module name and was unused.
        continue

      # For each of the tar entries, yield them IF and ONLY IF we have not
      # encountered the path before.
      for tar_info in tar_file:
        if not self.check_tar_info(tar_info):
          continue

        # Yield the tar header.
        yield tar_info.tobuf()

        # Try to extract any file contents for the tar. If found, we yield them as well.
        if tar_info.isreg():
          file_stream = tar_file.extractfile(tar_info)
          if file_stream is not None:
            length = 0
            while True:
              current_block = file_stream.read(chunk_size)
              if not len(current_block):
                break

              yield current_block
              length += len(current_block)

            file_stream.close()

            # File data must be padded to 512 byte multiples per the TAR spec.
            if length % 512 != 0:
              yield '\0' * (512 - (length % 512))

      # Close the layer stream now that we're done with it.
      tar_file.close()

      # Update the trie with the new encountered entries.
      self.trie = marisa_trie.Trie(self.encountered)

    # Last two records are empty in TAR spec.
    yield '\0' * 512
    yield '\0' * 512


  def check_tar_info(self, tar_info):
    """ Returns True if the given TAR entry should be included in the output
        stream, recording its path. Returns False for AUFS metadata entries,
        for whiteout entries (whose deleted-path prefix is recorded so lower
        layers' matching files are suppressed), and for paths already seen
        in a higher layer.
    """
    absolute = os.path.relpath(tar_info.name.decode('utf-8'), './')

    # Skip metadata.
    if is_aufs_metadata(absolute):
      return False

    # Add any prefix of deleted paths to the prefix list.
    deleted_prefix = get_deleted_prefix(absolute)
    if deleted_prefix is not None:
      self.encountered.append(deleted_prefix)
      return False

    # Check if this file has already been encountered somewhere. If so,
    # skip it.
    if unicode(absolute) in self.trie:
      return False

    # Otherwise, add the path to the encountered list and return it.
    self.encountered.append(absolute)
    return True