from datetime import datetime, timedelta

from data.database import QueueItem, db


# Smallest amount by which WorkQueue.extend_processing must move an item's
# processing expiration forward before the new value is written to the database.
MINIMUM_EXTENSION = timedelta(seconds=20)


class WorkQueue(object):
  """
  Work queue stored in the QueueItem table.

  Every item is filed under a canonical name of the form
  ``<queue_name>/<part>/.../``. An item is claimed with get(), and afterwards
  either removed with complete() or handed back with incomplete().
  """

  def __init__(self, queue_name, transaction_factory,
               canonical_name_match_list=None, reporter=None):
    """
    queue_name: first component of every canonical item name.
    transaction_factory: callable invoked as transaction_factory(db); the result is
      used as a context manager around database work.
    canonical_name_match_list: optional list of further name components. get() and
      update_metrics() only look at items whose canonical name starts with
      queue_name followed by these components.
    reporter: optional callable invoked by update_metrics() as
      reporter(currently_processing, running_count, running_count + available_count).
    """
    self._queue_name = queue_name
    self._reporter = reporter
    self._transaction_factory = transaction_factory
    self._currently_processing = False

    if canonical_name_match_list is None:
      self._canonical_name_match_list = []
    else:
      self._canonical_name_match_list = canonical_name_match_list

  @staticmethod
  def _canonical_name(name_list):
    """ Join the name components with '/' and append a trailing '/'. """
    return '/'.join(name_list) + '/'

  def _running_jobs(self, now, name_match_query):
    """
    Build a query for the queue names of items that are currently claimed: marked
    unavailable with a processing expiration still in the future.
    """
    # NOTE: `== False` is handed to where() as a query expression; do not rewrite it
    # as `not ...` / `is False`. Presumably peewee-style operators, where `**` is a
    # case-insensitive LIKE -- confirm against data.database.
    return (QueueItem
      .select(QueueItem.queue_name)
      .where(QueueItem.available == False,
             QueueItem.processing_expires > now,
             QueueItem.queue_name ** name_match_query))

  def _available_jobs(self, now, name_match_query, running_query):
    """
    Build a query for items that may be claimed right now: the name matches, the
    available_after time has passed, the item is either available or its previous
    claim has expired, it has retries left, and no item with the same queue name
    appears in running_query.
    """
    # Presumably peewee-style operators: `<<` is IN and `~` negates -- confirm.
    return (QueueItem
      .select()
      .where(QueueItem.queue_name ** name_match_query, QueueItem.available_after <= now,
             ((QueueItem.available == True) | (QueueItem.processing_expires <= now)),
             QueueItem.retries_remaining > 0, ~(QueueItem.queue_name << running_query)))

  def _name_match_query(self):
    """ Return the LIKE pattern ('<canonical prefix>%') selecting this queue's items. """
    return '%s%%' % self._canonical_name([self._queue_name] + self._canonical_name_match_list)

  def update_metrics(self):
    """
    Count running and available queue names and pass them to the reporter.

    Does nothing when no reporter was configured.
    """
    # Bail out before opening a transaction: without a reporter there is nothing to
    # query for.
    if self._reporter is None:
      return

    with self._transaction_factory(db):
      now = datetime.utcnow()
      name_match_query = self._name_match_query()

      running_query = self._running_jobs(now, name_match_query)
      running_count = running_query.distinct().count()

      available_query = self._available_jobs(now, name_match_query, running_query)
      available_count = available_query.select(QueueItem.queue_name).distinct().count()

      self._reporter(self._currently_processing, running_count, running_count + available_count)

  def put(self, canonical_name_list, message, available_after=0, retries_remaining=5):
    """
    Add an item to the queue.

    canonical_name_list: list of name components appended after the queue name.
    message: stored as the item's body.
    available_after: number of seconds to wait before the item may be claimed;
      None or 0 means it can be claimed immediately.
    retries_remaining: how many times the item may be claimed by get().
    """
    params = {
      'queue_name': self._canonical_name([self._queue_name] + canonical_name_list),
      'body': message,
      'retries_remaining': retries_remaining,
    }

    # `or 0` lets callers pass None for "no delay".
    available_date = datetime.utcnow() + timedelta(seconds=available_after or 0)
    params['available_after'] = available_date

    with self._transaction_factory(db):
      QueueItem.create(**params)

  def get(self, processing_time=300):
    """
    Claim the available item with the lowest id and return it, or return None when
    nothing is available.

    The claimed item is marked unavailable, its processing expiration is set to
    processing_time seconds from now (five minutes by default) and one retry is
    consumed.
    """
    now = datetime.utcnow()
    name_match_query = self._name_match_query()

    with self._transaction_factory(db):
      running = self._running_jobs(now, name_match_query)
      avail = self._available_jobs(now, name_match_query, running)

      # Only the lookup can raise DoesNotExist, so keep the try body to that line.
      try:
        item = avail.order_by(QueueItem.id).get()
      except QueueItem.DoesNotExist:
        self._currently_processing = False
        return None

      # NOTE(review): selecting and then saving is not atomic by itself; whether two
      # workers can claim the same item depends on what the transaction factory
      # provides -- confirm.
      item.available = False
      item.processing_expires = now + timedelta(seconds=processing_time)
      item.retries_remaining -= 1
      item.save()

      self._currently_processing = True
      return item

  def complete(self, completed_item):
    """ Delete a finished item and mark this queue as no longer processing. """
    with self._transaction_factory(db):
      completed_item.delete_instance()
      self._currently_processing = False

  def incomplete(self, incomplete_item, retry_after=300, restore_retry=False):
    """
    Hand an unfinished item back to the queue.

    retry_after: number of seconds before the item may be claimed again.
    restore_retry: when true, give back the retry that get() consumed.
    """
    with self._transaction_factory(db):
      retry_date = datetime.utcnow() + timedelta(seconds=retry_after)
      incomplete_item.available_after = retry_date
      incomplete_item.available = True

      if restore_retry:
        incomplete_item.retries_remaining += 1

      incomplete_item.save()
      self._currently_processing = False

  @staticmethod
  def extend_processing(queue_item, seconds_from_now):
    """
    Move queue_item's processing expiration to seconds_from_now seconds from now.

    The item is only saved when it has no expiration yet or when the new value is
    more than MINIMUM_EXTENSION later than the current one.
    """
    new_expiration = datetime.utcnow() + timedelta(seconds=seconds_from_now)
    current_expiration = queue_item.processing_expires

    # An item without an expiration cannot be compared against; always write then.
    # Otherwise only write if the expiration moves by more than the minimum.
    if current_expiration is None or new_expiration - current_expiration > MINIMUM_EXTENSION:
      queue_item.processing_expires = new_expiration
      queue_item.save()