78 lines
1.9 KiB
Python
78 lines
1.9 KiB
Python
|
import unittest
|
||
|
import json
|
||
|
import time
|
||
|
|
||
|
from functools import wraps
|
||
|
from threading import Thread, Lock
|
||
|
|
||
|
from app import app
|
||
|
from data.queue import WorkQueue
|
||
|
from initdb import wipe_database, initialize_database, populate_database
|
||
|
|
||
|
|
||
|
# Queue name passed to WorkQueue when the test fixture builds its queue.
QUEUE_NAME = 'testqueuename'
|
||
|
|
||
|
class AutoUpdatingQueue(object):
    """Proxy that calls ``update_metrics()`` on a queue after every call.

    Attribute lookups that are not satisfied by the proxy itself are
    forwarded to the wrapped queue. Callable attributes are handed back
    wrapped so that the wrapped queue's ``update_metrics()`` runs once the
    call has returned; non-callable attributes are returned unchanged.
    """

    def __init__(self, queue_to_wrap):
        """Remember the queue object that lookups are delegated to."""
        self._queue = queue_to_wrap

    def _wrapper(self, func):
        """Return ``func`` decorated to call ``update_metrics()`` afterwards.

        The metrics update only happens when ``func`` returns normally; an
        exception raised by ``func`` propagates before the update is reached.
        """
        @wraps(func)
        def wrapper(*args, **kwargs):
            to_return = func(*args, **kwargs)
            self._queue.update_metrics()
            return to_return
        return wrapper

    def __getattr__(self, attr_name):
        """Delegate a failed attribute lookup to the wrapped queue.

        Raises:
            AttributeError: if the wrapped queue has no such attribute, or
                if the proxy has no wrapped queue yet.
        """
        # __getattr__ only runs when normal lookup fails, so being asked for
        # '_queue' means the instance has not been initialised (for example
        # it was created through __new__ by copy or pickle). Reading
        # self._queue below would then re-enter this method without end.
        if attr_name == '_queue':
            raise AttributeError(attr_name)

        method_or_attr = getattr(self._queue, attr_name)
        if callable(method_or_attr):
            return self._wrapper(method_or_attr)
        return method_or_attr
|
||
|
|
||
|
|
||
|
class QueueTestCase(unittest.TestCase):
    """Base fixture giving each test a wrapped work queue and a reset database."""

    # JSON-encoded payload used as the message body by the queue tests.
    TEST_MESSAGE_1 = json.dumps({'data': 1})

    def setUp(self):
        """Build the queue under test, then rebuild the database.

        The queue is created from the application's configured transaction
        factory and wrapped in AutoUpdatingQueue. Afterwards the database is
        wiped, initialized and populated, in that order.
        """
        self.transaction_factory = app.config['DB_TRANSACTION_FACTORY']
        work_queue = WorkQueue(QUEUE_NAME, self.transaction_factory)
        self.queue = AutoUpdatingQueue(work_queue)

        for reset_step in (wipe_database, initialize_database, populate_database):
            reset_step()
|
||
|
|
||
|
|
||
|
class TestQueueThreads(QueueTestCase):
    """Checks that items put on the queue can be drained by concurrent threads."""

    def test_queue_threads(self):
        """Enqueue 20 items, fetch them from 100 threads, expect all consumed."""
        # A one-element list lets the worker closure mutate the counter.
        remaining = [20]
        for index in range(remaining[0]):
            self.queue.put([str(index)], self.TEST_MESSAGE_1)

        counter_lock = Lock()

        def consume_one(counter_lock, remaining, queue):
            # Fetch a single item; a None result means nothing was handed out
            # to this thread, so the counter is left alone.
            fetched = queue.get()
            if fetched is not None:
                self.assertEqual(self.TEST_MESSAGE_1, fetched.body)
                with counter_lock:
                    remaining[0] -= 1

        # Use several times more threads than queued items: some threads
        # receive None from get() and therefore never decrement the counter.
        workers = [
            Thread(target=consume_one, args=(counter_lock, remaining, self.queue))
            for _ in range(100)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        self.assertEqual(remaining[0], 0)
|
||
|
|
||
|
|
||
|
# Run the tests in this module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
||
|
|