import unittest
import json
import time
from functools import wraps

from app import app
from initdb import setup_database_for_testing, finished_database_for_testing
from data.queue import WorkQueue


QUEUE_NAME = 'testqueuename'


class SaveLastCountReporter(object):
    """Callable metrics reporter that remembers the most recent report.

    Each call overwrites the three stored values, so tests can inspect the
    last numbers the queue handed to its reporter. All three are ``None``
    until the first call.
    """

    def __init__(self):
        self.currently_processing = None
        self.running_count = None
        self.total = None

    def __call__(self, currently_processing, running_count, total_jobs):
        self.currently_processing = currently_processing
        self.running_count = running_count
        self.total = total_jobs


class AutoUpdatingQueue(object):
    """Proxy around a queue that refreshes metrics after every method call.

    Attribute access is forwarded to the wrapped queue. Callable attributes
    are wrapped so that ``update_metrics()`` is invoked on the wrapped queue
    right after the call returns; non-callable attributes are returned as-is.
    """

    def __init__(self, queue_to_wrap):
        self._queue = queue_to_wrap

    def _wrapper(self, func):
        """Return ``func`` wrapped to call ``update_metrics()`` afterwards."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            to_return = func(*args, **kwargs)
            self._queue.update_metrics()
            return to_return
        return wrapper

    def __getattr__(self, attr_name):
        # Only reached when normal lookup on the proxy fails, i.e. for
        # everything that lives on the wrapped queue. Note that
        # ``update_metrics`` itself is callable, so calling it through the
        # proxy triggers the refresh twice.
        method_or_attr = getattr(self._queue, attr_name)
        if callable(method_or_attr):
            return self._wrapper(method_or_attr)
        return method_or_attr


class QueueTestCase(unittest.TestCase):
    """Base test case providing a reporter-backed, auto-updating work queue."""

    TEST_MESSAGE_1 = json.dumps({'data': 1})
    TEST_MESSAGE_2 = json.dumps({'data': 2})
    # One hundred distinct JSON bodies: {"data": "1"} .. {"data": "100"}.
    TEST_MESSAGES = [json.dumps({'data': str(i)}) for i in range(1, 101)]

    def setUp(self):
        self.reporter = SaveLastCountReporter()
        self.transaction_factory = app.config['DB_TRANSACTION_FACTORY']
        self.queue = AutoUpdatingQueue(
            WorkQueue(QUEUE_NAME, self.transaction_factory,
                      reporter=self.reporter))
        setup_database_for_testing(self)

    def tearDown(self):
        finished_database_for_testing(self)


class TestQueue(QueueTestCase):
    """Behavioural tests for ``WorkQueue`` driven through the proxy queue."""

    def test_same_canonical_names(self):
        # Two items under the same canonical name: the test expects them to
        # be counted as a single job and handed out strictly one at a time.
        self.assertIsNone(self.reporter.currently_processing)
        self.assertIsNone(self.reporter.running_count)
        self.assertIsNone(self.reporter.total)

        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_2)
        self.assertEqual(self.reporter.currently_processing, False)
        self.assertEqual(self.reporter.running_count, 0)
        self.assertEqual(self.reporter.total, 1)

        one = self.queue.get(ordering_required=True)
        self.assertIsNotNone(one)
        self.assertEqual(self.TEST_MESSAGE_1, one.body)
        self.assertEqual(self.reporter.currently_processing, True)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 1)

        # While the first item is outstanding, nothing else is available.
        two_fail = self.queue.get(ordering_required=True)
        self.assertIsNone(two_fail)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 1)

        self.queue.complete(one)
        self.assertEqual(self.reporter.currently_processing, False)
        self.assertEqual(self.reporter.running_count, 0)
        self.assertEqual(self.reporter.total, 1)

        two = self.queue.get(ordering_required=True)
        self.assertIsNotNone(two)
        self.assertEqual(self.reporter.currently_processing, True)
        self.assertEqual(self.TEST_MESSAGE_2, two.body)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 1)

    def test_different_canonical_names(self):
        # Items under different canonical names are expected to be counted
        # separately and to be retrievable concurrently.
        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
        self.queue.put(['abc', 'ghi'], self.TEST_MESSAGE_2)
        self.assertEqual(self.reporter.running_count, 0)
        self.assertEqual(self.reporter.total, 2)

        one = self.queue.get(ordering_required=True)
        self.assertIsNotNone(one)
        self.assertEqual(self.TEST_MESSAGE_1, one.body)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 2)

        two = self.queue.get(ordering_required=True)
        self.assertIsNotNone(two)
        self.assertEqual(self.TEST_MESSAGE_2, two.body)
        self.assertEqual(self.reporter.running_count, 2)
        self.assertEqual(self.reporter.total, 2)

    def test_canonical_name(self):
        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
        self.queue.put(['abc', 'def', 'ghi'], self.TEST_MESSAGE_1)

        # NOTE(review): these assertions compare a string against the item
        # object returned by get(), so they can never fail (they would even
        # pass for None). Presumably the item's queue-name attribute was
        # meant to be checked -- confirm the intended attribute and
        # expectation before tightening.
        one = self.queue.get(ordering_required=True)
        self.assertNotEqual(QUEUE_NAME + '/abc/def/', one)

        two = self.queue.get(ordering_required=True)
        self.assertNotEqual(QUEUE_NAME + '/abc/def/ghi/', two)

    def test_expiration(self):
        # An item fetched with a short processing_time is expected to become
        # available again once that time has elapsed.
        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
        self.assertEqual(self.reporter.running_count, 0)
        self.assertEqual(self.reporter.total, 1)

        one = self.queue.get(processing_time=0.5, ordering_required=True)
        self.assertIsNotNone(one)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 1)

        one_fail = self.queue.get(ordering_required=True)
        self.assertIsNone(one_fail)

        # Sleep past the 0.5s processing window, then refresh the metrics so
        # the reporter reflects the expired reservation.
        time.sleep(1)
        self.queue.update_metrics()
        self.assertEqual(self.reporter.running_count, 0)
        self.assertEqual(self.reporter.total, 1)

        one_again = self.queue.get(ordering_required=True)
        self.assertIsNotNone(one_again)
        self.assertEqual(self.reporter.running_count, 1)
        self.assertEqual(self.reporter.total, 1)

    def test_specialized_queue(self):
        self.queue.put(['abc', 'def'], self.TEST_MESSAGE_1)
        self.queue.put(['def', 'def'], self.TEST_MESSAGE_2)

        # Second queue over the same name, constructed with ['def'] as its
        # third argument (presumably a canonical-name filter -- see
        # WorkQueue). The test expects it to see only the ['def', 'def'] item.
        my_queue = AutoUpdatingQueue(
            WorkQueue(QUEUE_NAME, self.transaction_factory, ['def']))

        two = my_queue.get(ordering_required=True)
        self.assertIsNotNone(two)
        self.assertEqual(self.TEST_MESSAGE_2, two.body)

        one_fail = my_queue.get(ordering_required=True)
        self.assertIsNone(one_fail)

        one = self.queue.get(ordering_required=True)
        self.assertIsNotNone(one)
        self.assertEqual(self.TEST_MESSAGE_1, one.body)

    def test_random_queue_no_duplicates(self):
        # Without ordering_required, every queued message must still be
        # handed out exactly once.
        for msg in self.TEST_MESSAGES:
            self.queue.put(['abc', 'def'], msg)

        seen = set()
        for _ in range(len(self.TEST_MESSAGES)):
            item = self.queue.get()
            # Fail with a clear assertion instead of an AttributeError on
            # ``item.body`` if the queue runs dry early.
            self.assertIsNotNone(item)
            json_body = json.loads(item.body)
            msg = str(json_body['data'])
            self.assertNotIn(msg, seen)
            seen.add(msg)

        for body in self.TEST_MESSAGES:
            json_body = json.loads(body)
            msg = str(json_body['data'])
            self.assertIn(msg, seen)


if __name__ == '__main__':
    unittest.main()