import os
import tarfile
import copy

class TarLayerReadException(Exception):
  """ Exception raised when reading a layer has failed. """
  pass


class TarLayerFormat(object):
  """ Class which creates a generator of the combined TAR data. """
  def __init__(self, tar_iterator, path_prefix=None):
    self.tar_iterator = tar_iterator
    self.path_prefix = path_prefix

  def get_generator(self):
    for current_tar in self.tar_iterator():
      # Read the current TAR. If it is empty, we just continue
      # to the next one.
      tar_file = None
      try:
        tar_file = tarfile.open(mode='r|*', fileobj=current_tar)
      except tarfile.ReadError as re:
        if str(re) != 'empty file':
          raise TarLayerReadException('Could not read layer')

      if not tar_file:
        continue

      # For each of the tar entries, yield them IF and ONLY IF check_tar_info
      # approves them (e.g. the path has not been encountered before).

      # 9MB (+ padding below) so that it matches the 10MB expected by Gzip.
      chunk_size = 1024 * 1024 * 9

      for tar_info in tar_file:
        if not self.check_tar_info(tar_info):
          continue

        # Yield the tar header.
        if self.path_prefix:
          # Note: We use a copy here because we need to make sure we copy over all the internal
          # data of the tar header. We cannot use frombuf(tobuf()), however, because it doesn't
          # properly handle large filenames.
          clone = copy.deepcopy(tar_info)
          clone.name = os.path.join(self.path_prefix, clone.name)

          # If the entry is a *hard* link, then prefix it as well. Soft links are relative.
          if clone.linkname and clone.type == tarfile.LNKTYPE:
            clone.linkname = os.path.join(self.path_prefix, clone.linkname)

          yield clone.tobuf()
        else:
          yield tar_info.tobuf()

        # Try to extract any file contents for the tar. If found, we yield them as well.
        if tar_info.isreg():
          file_stream = tar_file.extractfile(tar_info)
          if file_stream is not None:
            length = 0
            while True:
              current_block = file_stream.read(chunk_size)
              if not current_block:
                break

              yield current_block
              length += len(current_block)

            file_stream.close()

            # Files must be padded to 512-byte multiples.
            if length % 512 != 0:
              yield '\0' * (512 - (length % 512))

      # Close the layer stream now that we're done with it.
      tar_file.close()

      # Conduct any post-tar work.
      self.after_tar_layer(current_tar)

    # Last two records are empty in TAR spec.
    yield '\0' * 512
    yield '\0' * 512


  def check_tar_info(self, tar_info):
    """ Returns true if the current tar_info should be added to the combined tar. False
        otherwise.
    """
    raise NotImplementedError()

  def after_tar_layer(self, current_tar):
    """ Invoked after a TAR layer is added, to do any post-add work. """
    raise NotImplementedError()
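

# ---------------------------------------------------------------------------
# Illustrative sketch only: a minimal example of how the two abstract hooks
# above might be implemented by a subclass. The class name and its simple
# path-deduplication policy are hypothetical and not part of this module's
# public API.
# ---------------------------------------------------------------------------
class ExamplePathDedupLayerFormat(TarLayerFormat):
  """ Example subclass which yields only the first occurrence of each path. """
  def __init__(self, tar_iterator, path_prefix=None):
    super(ExamplePathDedupLayerFormat, self).__init__(tar_iterator, path_prefix)
    self.encountered_paths = set()

  def check_tar_info(self, tar_info):
    # Skip entries whose path has already been seen in an earlier layer.
    if tar_info.name in self.encountered_paths:
      return False

    self.encountered_paths.add(tar_info.name)
    return True

  def after_tar_layer(self, current_tar):
    # Nothing to clean up in this sketch; a real subclass might close or
    # otherwise release the underlying layer stream here.
    pass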