import unittest
import json
import time

from datetime import datetime, timedelta
from functools import wraps

from app import app
from initdb import setup_database_for_testing, finished_database_for_testing
from data.queue import WorkQueue
from datetime import timedelta


# Queue name passed to every WorkQueue constructed by these tests.
QUEUE_NAME = 'testqueuename'


class SaveLastCountReporter(object):
  """Reporter stub that remembers only the arguments of its latest call.

  All three fields start out as None, so a test can distinguish "no report
  has happened yet" from a report of False / 0.
  """

  def __init__(self):
    self.currently_processing = self.running_count = self.total = None

  def __call__(self, currently_processing, running_count, total_jobs):
    # Each report replaces the previous one; nothing is accumulated.
    self.currently_processing, self.running_count, self.total = (
      currently_processing, running_count, total_jobs)


class AutoUpdatingQueue(object):
  """Proxy that refreshes a queue's metrics after every method call.

  Attribute lookups not satisfied by the proxy are forwarded to the wrapped
  queue. Callable attributes come back wrapped so that, once the call
  returns, ``update_metrics()`` is invoked on the wrapped queue. Non-callable
  attributes are returned unchanged.

  The refresh happens when the method *returns*. For a method whose return
  value is used as a context manager, that is before the ``with`` body runs.
  Calling ``update_metrics`` through the proxy runs it twice: once as the
  requested call and once as the refresh.
  """

  def __init__(self, queue_to_wrap):
    self._queue = queue_to_wrap

  def _wrapper(self, func):
    """Return ``func`` wrapped so queue metrics are refreshed after it returns.

    If ``func`` raises, the exception propagates and no refresh happens.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
      to_return = func(*args, **kwargs)
      self._queue.update_metrics()
      return to_return
    return wrapper

  def __getattr__(self, attr_name):
    """Forward lookups that normal attribute access could not resolve.

    Raises:
      AttributeError: if the wrapped queue has no ``attr_name``, or if this
        proxy was created without running ``__init__``.
    """
    # __getattr__ only runs after normal lookup fails. If ``_queue`` itself
    # is missing (an instance created via __new__, e.g. by copy or pickle),
    # reading self._queue below would re-enter this method without end, so
    # report it as a plain missing attribute instead.
    if attr_name == '_queue':
      raise AttributeError(attr_name)
    method_or_attr = getattr(self._queue, attr_name)
    if callable(method_or_attr):
      return self._wrapper(method_or_attr)
    return method_or_attr


class QueueTestCase(unittest.TestCase):
  """Shared fixture: a test database plus an auto-updating WorkQueue.

  ``self.queue`` refreshes its metrics after every method call, and
  ``self.reporter`` holds the most recent metrics report for assertions.
  """

  TEST_MESSAGE_1 = json.dumps({'data': 1})
  TEST_MESSAGE_2 = json.dumps({'data': 2})
  # One hundred distinct payloads; the 'data' values are the strings '1'..'100'.
  TEST_MESSAGES = [json.dumps({'data': str(num)}) for num in range(1, 101)]

  def setUp(self):
    self.reporter = SaveLastCountReporter()
    self.transaction_factory = app.config['DB_TRANSACTION_FACTORY']
    work_queue = WorkQueue(QUEUE_NAME, self.transaction_factory, reporter=self.reporter)
    self.queue = AutoUpdatingQueue(work_queue)
    # Database preparation runs only after the queue objects are constructed.
    setup_database_for_testing(self)

  def tearDown(self):
    finished_database_for_testing(self)


class TestQueue(QueueTestCase):
  """Behavioral tests for WorkQueue, driven through the AutoUpdatingQueue fixture."""

  def test_same_canonical_names(self):
    """With ordering_required, items sharing a canonical name are handed out one at a time."""
    self.assertIsNone(self.reporter.currently_processing)
    self.assertIsNone(self.reporter.running_count)
    self.assertIsNone(self.reporter.total)

    id_1 = int(self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1))
    id_2 = int(self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2, available_after=-1))
    self.assertEqual(id_1 + 1, id_2)
    self.assertEqual(self.reporter.currently_processing, False)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 1)

    one = self.queue.get(ordering_required=True)
    self.assertIsNotNone(one)
    self.assertEqual(self.TEST_MESSAGE_1, one.body)
    self.assertEqual(self.reporter.currently_processing, True)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 1)

    # The second item shares the running item's name, so it must be withheld.
    two_fail = self.queue.get(ordering_required=True)
    self.assertIsNone(two_fail)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 1)

    self.queue.complete(one)
    self.assertEqual(self.reporter.currently_processing, False)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 1)

    two = self.queue.get(ordering_required=True)
    self.assertIsNotNone(two)
    self.assertEqual(self.reporter.currently_processing, True)
    self.assertEqual(self.TEST_MESSAGE_2, two.body)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 1)

  def test_different_canonical_names(self):
    """Items with different canonical names can run concurrently under ordering_required."""
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
    self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2, available_after=-1)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 2)

    one = self.queue.get(ordering_required=True)
    self.assertIsNotNone(one)
    self.assertEqual(self.TEST_MESSAGE_1, one.body)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 2)

    two = self.queue.get(ordering_required=True)
    self.assertIsNotNone(two)
    self.assertEqual(self.TEST_MESSAGE_2, two.body)
    self.assertEqual(self.reporter.running_count, 2)
    self.assertEqual(self.reporter.total, 2)

  def test_canonical_name(self):
    """A name that extends another (['abc', 'def', 'ghi'] vs ['abc', 'def']) is distinct.

    Both items must be retrievable at the same time under ordering_required.
    If the names collided, the second get() would return None, as it does in
    test_same_canonical_names.
    """
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
    self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1, available_after=-1)

    # Assert on the returned items themselves. Comparing an item against a
    # name string would always be "not equal" and could never fail.
    one = self.queue.get(ordering_required=True)
    self.assertIsNotNone(one)
    self.assertEqual(self.TEST_MESSAGE_1, one.body)

    two = self.queue.get(ordering_required=True)
    self.assertIsNotNone(two)
    self.assertEqual(self.TEST_MESSAGE_1, two.body)

  def test_expiration(self):
    """An item whose processing_time elapses without completion becomes available again."""
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 1)

    one = self.queue.get(processing_time=0.5, ordering_required=True)
    self.assertIsNotNone(one)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 1)

    one_fail = self.queue.get(ordering_required=True)
    self.assertIsNone(one_fail)

    # Wait past the 0.5s processing window so the claim lapses.
    time.sleep(1)
    self.queue.update_metrics()
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 1)

    one_again = self.queue.get(ordering_required=True)
    self.assertIsNotNone(one_again)
    self.assertEqual(self.reporter.running_count, 1)
    self.assertEqual(self.reporter.total, 1)

  def test_alive(self):
    """alive() is true while an item is queued or running, and false once completed."""
    # No queue item = not alive.
    self.assertFalse(self.queue.alive(['abc', 'def']))

    # Add a queue item.
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
    self.assertTrue(self.queue.alive(['abc', 'def']))

    # Retrieve the queue item.
    queue_item = self.queue.get()
    self.assertIsNotNone(queue_item)
    self.assertTrue(self.queue.alive(['abc', 'def']))

    # Make sure it is running by trying to retrieve it again.
    self.assertIsNone(self.queue.get())

    # Delete the queue item.
    self.queue.complete(queue_item)
    self.assertFalse(self.queue.alive(['abc', 'def']))

  def test_specialized_queue(self):
    """A WorkQueue constructed with ['def'] only hands out items put under 'def'."""
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
    self.queue.put(['def', 'def'], self.TEST_MESSAGE_2, available_after=-1)

    my_queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory, ['def']))

    two = my_queue.get(ordering_required=True)
    self.assertIsNotNone(two)
    self.assertEqual(self.TEST_MESSAGE_2, two.body)

    # The remaining item lives under 'abc', outside this queue's view.
    one_fail = my_queue.get(ordering_required=True)
    self.assertIsNone(one_fail)

    one = self.queue.get(ordering_required=True)
    self.assertIsNotNone(one)
    self.assertEqual(self.TEST_MESSAGE_1, one.body)

  def test_random_queue_no_duplicates(self):
    """Unordered get() hands out every queued item exactly once."""
    for msg in self.TEST_MESSAGES:
      self.queue.put(['abc', 'def'], msg, available_after=-1)
    seen = set()

    for _ in range(len(self.TEST_MESSAGES)):
      item = self.queue.get()
      # Fail with a clear assertion rather than an AttributeError on None.
      self.assertIsNotNone(item)
      msg = str(json.loads(item.body)['data'])
      self.assertNotIn(msg, seen)
      seen.add(msg)

    for body in self.TEST_MESSAGES:
      msg = str(json.loads(body)['data'])
      self.assertIn(msg, seen)

  def test_bulk_insert(self):
    """Items added via batch_insert() are visible to the queue metrics."""
    self.assertIsNone(self.reporter.currently_processing)
    self.assertIsNone(self.reporter.running_count)
    self.assertIsNone(self.reporter.total)

    # AutoUpdatingQueue refreshes metrics when batch_insert() returns, which
    # is before the with-body runs, so refresh explicitly afterwards.
    with self.queue.batch_insert() as queue_put:
      queue_put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-1)
      queue_put(['abc', 'def'], self.TEST_MESSAGE_2, available_after=-1)

    self.queue.update_metrics()
    self.assertEqual(self.reporter.currently_processing, False)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 1)

    with self.queue.batch_insert() as queue_put:
      queue_put(['abd', 'def'], self.TEST_MESSAGE_1, available_after=-1)
      queue_put(['abd', 'ghi'], self.TEST_MESSAGE_2, available_after=-1)

    self.queue.update_metrics()
    self.assertEqual(self.reporter.currently_processing, False)
    self.assertEqual(self.reporter.running_count, 0)
    self.assertEqual(self.reporter.total, 3)

  def test_num_available_between(self):
    """num_available_jobs_between() counts only items whose availability falls in the window."""
    now = datetime.utcnow()
    self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1, available_after=-10)
    self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2, available_after=-5)

    # Partial results
    count = self.queue.num_available_jobs_between(now - timedelta(seconds=8), now, ['abc'])
    self.assertEqual(1, count)

    # All results
    count = self.queue.num_available_jobs_between(now - timedelta(seconds=20), now, ['/abc'])
    self.assertEqual(2, count)

    # No results. Pass the same list form as above so that the zero count is
    # caused only by the empty time window, not by a malformed name argument.
    count = self.queue.num_available_jobs_between(now, now, ['abc'])
    self.assertEqual(0, count)

# Allow running this module directly (python <file>) as well as via a test runner.
if __name__ == '__main__':
  unittest.main()