import marisa_trie
import os
import tarfile


# Marker that starts the basename of an AUFS whiteout entry. Whatever follows
# the marker is the name of the path that the layer removes.
AUFS_WHITEOUT = u'.wh.'
AUFS_WHITEOUT_PREFIX_LENGTH = len(AUFS_WHITEOUT)

# AUFS bookkeeping entries carry the whiteout marker twice ('.wh..wh.'); they
# are skipped entirely rather than being treated as removals.
AUFS_METADATA = AUFS_WHITEOUT * 2

# Tar member types that are reported as files (the two regular-file types).
ALLOWED_TYPES = set([tarfile.REGTYPE, tarfile.AREGTYPE])


def files_and_dirs_from_tar(source_stream, removed_prefix_collector):
  """Yield the absolute path of every regular file in a layer tarball.

  The archive is read sequentially (tar stream mode, any compression), so
  source_stream only needs to support read().

  Args:
    source_stream: file-like object positioned at the start of the tar data.
    removed_prefix_collector: object with an add() method (e.g. a set). For
      every AUFS whiteout entry found, the absolute path that the whiteout
      removes is added to it. It is filled lazily, while the generator is
      being consumed.

  Yields:
    '/'-prefixed paths of the members whose type is in ALLOWED_TYPES.
    Directories, links, devices, AUFS metadata and whiteout entries are
    never yielded.
  """
  tar_stream = tarfile.open(mode='r|*', fileobj=source_stream)

  try:
    for tar_info in tar_stream:
      # Python 2 reports member names as byte strings while Python 3 reports
      # already-decoded text; only decode when there is something to decode.
      name = tar_info.name
      if isinstance(name, bytes):
        name = name.decode('utf-8')

      # Normalise away a leading './' so that every path has the same shape.
      absolute = os.path.relpath(name, './')
      dirname = os.path.dirname(absolute)
      filename = os.path.basename(absolute)

      # AUFS bookkeeping entries describe neither a file nor a removal.
      if (filename.startswith(AUFS_METADATA) or
          absolute.startswith(AUFS_METADATA)):
        continue

      # A whiteout '<dir>/.wh.<name>' records that '<dir>/<name>' was removed.
      if filename.startswith(AUFS_WHITEOUT):
        removed_filename = filename[AUFS_WHITEOUT_PREFIX_LENGTH:]
        removed_prefix = os.path.join('/', dirname, removed_filename)
        removed_prefix_collector.add(removed_prefix)
        continue

      # Everything that is not a regular file (directories, links, ...) is
      # dropped here.
      if tar_info.type in ALLOWED_TYPES:
        yield '/' + absolute
  finally:
    # Release the tarfile wrapper even when the consumer stops iterating
    # early or a member fails to parse.
    tar_stream.close()


def __compute_removed(base_trie, removed_prefixes):
  """Yield the keys of base_trie that are deleted by the given whiteouts.

  Args:
    base_trie: object whose keys(prefix) returns the stored paths that start
      with prefix (e.g. a marisa_trie.Trie).
    removed_prefixes: iterable of absolute paths that were removed. Each one
      names either a single file or a directory whose whole subtree is gone.

  Yields:
    Every stored path equal to a removed path or located underneath it. A
    path may be yielded more than once when removed paths overlap.
  """
  for prefix in removed_prefixes:
    # keys(prefix) is a plain string-prefix match, so '/foo' would also select
    # the unrelated sibling '/foobar'. Only accept the exact path or paths
    # below it, i.e. those that continue with a path separator.
    directory_prefix = prefix.rstrip('/') + '/'
    for filename in base_trie.keys(prefix):
      if filename == prefix or filename.startswith(directory_prefix):
        yield filename


def __compute_added_changed(base_trie, delta_trie):
  """Split the keys of delta_trie into brand-new and already-known paths.

  Args:
    base_trie: container supporting `in` that holds the previous paths.
    delta_trie: object whose keys() returns the paths present in the delta.

  Returns:
    A tuple (added, changed) of sets: added holds the delta paths missing
    from base_trie, changed holds the delta paths already present in it.
  """
  added, changed = set(), set()

  for name in delta_trie.keys():
    bucket = changed if name in base_trie else added
    bucket.add(name)

  return added, changed


def __new_fs(base_trie, added, removed):
  """Yield the paths that make up the filesystem after applying a delta.

  Args:
    base_trie: object whose keys() returns the previous paths.
    added: iterable of paths to append after the surviving base paths.
    removed: container supporting `in`; base paths found in it are dropped.

  Yields:
    First every key of base_trie that is not in removed, then every element
    of added.
  """
  for existing in base_trie.keys():
    if existing in removed:
      continue
    yield existing

  for created in added:
    yield created


def empty_fs():
  """Return a trie without any keys, representing a filesystem with no files."""
  no_files = marisa_trie.Trie()
  return no_files


def compute_new_diffs_and_fs(base_trie, filename_source,
                             removed_prefix_collector):
  """Apply one delta to base_trie and report what it changed.

  Args:
    base_trie: trie holding the paths of the filesystem before the delta.
    filename_source: iterable of the paths contained in the delta.
    removed_prefix_collector: iterable of the paths removed by the delta.

  Returns:
    A tuple (new_fs, added, changed, removed) where new_fs is a trie of the
    resulting paths, added and changed are sets of delta paths that were
    absent from / present in base_trie, and removed is the key list of the
    trie built from the base paths selected for removal.
  """
  # Keep this first: building the trie consumes filename_source. If that is a
  # generator which fills removed_prefix_collector as a side effect (as
  # files_and_dirs_from_tar does), the collector is only complete afterwards.
  layer_files = marisa_trie.Trie(filename_source)
  added, changed = __compute_added_changed(base_trie, layer_files)

  removed_names = __compute_removed(base_trie, removed_prefix_collector)
  removed_trie = marisa_trie.Trie(removed_names)

  surviving = __new_fs(base_trie, added, removed_trie)

  return (marisa_trie.Trie(surviving), added, changed, removed_trie.keys())