import os
import shutil
import hashlib
import io
import logging
import psutil

from uuid import uuid4

from storage.basestorage import BaseStorageV2
from digest import digest_tools


logger = logging.getLogger(__name__)


class LocalStorage(BaseStorageV2):

  def __init__(self, storage_path):
    self._root_path = storage_path

  def _init_path(self, path=None, create=False):
    path = os.path.join(self._root_path, path) if path else self._root_path
    if create is True:
      dirname = os.path.dirname(path)
      if not os.path.exists(dirname):
        os.makedirs(dirname)
    return path

  def get_content(self, path):
    path = self._init_path(path)
    with open(path, mode='r') as f:
      return f.read()

  def put_content(self, path, content):
    path = self._init_path(path, create=True)
    with open(path, mode='w') as f:
      f.write(content)
    return path

  def stream_read(self, path):
    path = self._init_path(path)
    with open(path, mode='rb') as f:
      while True:
        buf = f.read(self.buffer_size)
        if not buf:
          break
        yield buf

  def stream_read_file(self, path):
    path = self._init_path(path)
    return io.open(path, mode='rb')

  def stream_write(self, path, fp, content_type=None, content_encoding=None):
    # Size is mandatory
    path = self._init_path(path, create=True)
    with open(path, mode='wb') as out_fp:
      self._stream_write_to_fp(fp, out_fp)

  def _stream_write_to_fp(self, in_fp, out_fp, num_bytes=-1):
    """ Copy the specified number of bytes from the input file stream to the output stream. If
        num_bytes < 0 copy until the stream ends.
    """
    bytes_copied = 0
    bytes_remaining = num_bytes
    while bytes_remaining > 0 or num_bytes < 0:
      try:
        buf = in_fp.read(self.buffer_size)
        if not buf:
          break
        out_fp.write(buf)
        bytes_copied += len(buf)
      except IOError:
        break

    return bytes_copied

  def list_directory(self, path=None):
    path = self._init_path(path)
    prefix = path[len(self._root_path) + 1:] + '/'
    exists = False
    for d in os.listdir(path):
      exists = True
      yield prefix + d
    if exists is False:
      # Raises OSError even when the directory is empty
      # (to be consistent with S3)
      raise OSError('No such directory: \'{0}\''.format(path))

  def exists(self, path):
    path = self._init_path(path)
    return os.path.exists(path)

  def remove(self, path):
    path = self._init_path(path)
    if os.path.isdir(path):
      shutil.rmtree(path)
      return
    try:
      os.remove(path)
    except OSError:
      pass

  def get_checksum(self, path):
    path = self._init_path(path)
    sha_hash = hashlib.sha256()
    with open(path, 'r') as to_hash:
      while True:
        buf = to_hash.read(self.buffer_size)
        if not buf:
          break
        sha_hash.update(buf)
    return sha_hash.hexdigest()[:7]


  def _rel_upload_path(self, uuid):
    return 'uploads/{0}'.format(uuid)


  def initiate_chunked_upload(self):
    new_uuid = str(uuid4())

    # Just create an empty file at the path
    with open(self._init_path(self._rel_upload_path(new_uuid), create=True), 'w'):
      pass

    return new_uuid

  def stream_upload_chunk(self, uuid, offset, length, in_fp):
    with open(self._init_path(self._rel_upload_path(uuid)), 'r+b') as upload_storage:
      upload_storage.seek(offset)
      return self._stream_write_to_fp(in_fp, upload_storage, length)

  def complete_chunked_upload(self, uuid, final_path, digest_to_verify):
    """ Verify the upload identified by `uuid` and move it into place at `final_path`.

        The SHA-256 digest of the uploaded content is computed via digest_tools and compared with
        `digest_to_verify`; digest_tools.InvalidDigestException is raised when they differ. On
        success the upload file is moved to `final_path` unless something already exists there,
        in which case the existing content is kept.
    """
    content_path = self._rel_upload_path(uuid)
    content_digest = digest_tools.sha256_digest_from_generator(self.stream_read(content_path))

    if not digest_tools.digests_equal(content_digest, digest_to_verify):
      msg = 'Given: {0} Computed: {1}'.format(digest_to_verify, content_digest)
      raise digest_tools.InvalidDigestException(msg)

    final_path_abs = self._init_path(final_path, create=True)
    # final_path_abs is already resolved against the storage root, so test the filesystem
    # directly. self.exists() would join it onto the root a second time, which points at the
    # wrong location whenever the root is a relative path.
    if not os.path.exists(final_path_abs):
      logger.debug('Moving content into place at path: %s', final_path_abs)
      shutil.move(self._init_path(content_path), final_path_abs)
    else:
      # NOTE(review): the upload file is left in place in this branch; confirm that a caller
      # is responsible for cleaning it up.
      logger.debug('Content already exists at path: %s', final_path_abs)

  def validate(self):
    """ Check that the storage root path lies under a mounted volume.

        Returns None when an accepted mount is found, or when the list of disk partitions cannot
        be loaded (in which case the failure is logged and validation is skipped). Raises
        Exception when no accepted mount contains the storage root.
    """
    # Load the set of disk mounts.
    try:
      mounts = psutil.disk_partitions(all=True)
    except Exception:
      # Catch Exception rather than everything so that KeyboardInterrupt and SystemExit still
      # propagate.
      logger.exception('Could not load disk partitions')
      return

    # Verify that the storage's root path is under a mounted Docker volume.
    root_path = self._root_path
    for mount in mounts:
      mountpoint = mount.mountpoint
      if mountpoint == '/':
        continue

      # Match on a path-component boundary so that a mount at /data does not accept a storage
      # root such as /database.
      mount_prefix = mountpoint.rstrip('/') + '/'
      under_mount = root_path == mountpoint or root_path.startswith(mount_prefix)

      # NOTE(review): only mounts whose device is 'tmpfs' are accepted here. That looks at odds
      # with the data-loss message below; confirm whether this device check is intended.
      if under_mount and mount.device == 'tmpfs':
        return

    raise Exception('Storage path %s is not under a mounted volume.\n\n'
                    'Registry data must be stored under a mounted volume '
                    'to prevent data loss' % self._root_path)