import unittest import json import time from functools import wraps from threading import Thread, Lock from app import app from data.queue import WorkQueue from initdb import wipe_database, initialize_database, populate_database QUEUE_NAME = 'testqueuename' class AutoUpdatingQueue(object): def __init__(self, queue_to_wrap): self._queue = queue_to_wrap def _wrapper(self, func): @wraps(func) def wrapper(*args, **kwargs): to_return = func(*args, **kwargs) self._queue.update_metrics() return to_return return wrapper def __getattr__(self, attr_name): method_or_attr = getattr(self._queue, attr_name) if callable(method_or_attr): return self._wrapper(method_or_attr) else: return method_or_attr class QueueTestCase(unittest.TestCase): TEST_MESSAGE_1 = json.dumps({'data': 1}) def setUp(self): self.transaction_factory = app.config['DB_TRANSACTION_FACTORY'] self.queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory)) wipe_database() initialize_database() populate_database() class TestQueueThreads(QueueTestCase): def test_queue_threads(self): count = [20] for i in range(count[0]): self.queue.put([str(i)], self.TEST_MESSAGE_1) lock = Lock() def get(lock, count, queue): item = queue.get() if item is None: return self.assertEqual(self.TEST_MESSAGE_1, item.body) with lock: count[0] -= 1 threads = [] # The thread count needs to be a few times higher than the queue size # count because some threads will get a None and thus won't decrement # the counter. for i in range(100): t = Thread(target=get, args=(lock, count, self.queue)) threads.append(t) for t in threads: t.start() for t in threads: t.join() self.assertEqual(count[0], 0) if __name__ == '__main__': unittest.main()