diff --git a/data/queue.py b/data/queue.py index 3ce6c7a6f..3cf921c98 100644 --- a/data/queue.py +++ b/data/queue.py @@ -50,11 +50,12 @@ class WorkQueue(object): QueueItem.queue_name ** name_match_query)) def _available_jobs(self, now, name_match_query): - return (QueueItem - .select() - .where(QueueItem.queue_name ** name_match_query, QueueItem.available_after <= now, + return self._available_jobs_where(QueueItem.select(), now, name_match_query) + + def _available_jobs_where(self, query, now, name_match_query): + return query.where(QueueItem.queue_name ** name_match_query, QueueItem.available_after <= now, ((QueueItem.available == True) | (QueueItem.processing_expires <= now)), - QueueItem.retries_remaining > 0)) + QueueItem.retries_remaining > 0) def _available_jobs_not_running(self, now, name_match_query, running_query): return (self @@ -145,25 +146,30 @@ class WorkQueue(object): item = None try: - db_item_candidate = avail.order_by(QueueItem.id).get() - - with self._transaction_factory(db): - still_available_query = (db_for_update(self - ._available_jobs(now, name_match_query) - .where(QueueItem.id == db_item_candidate.id))) - - db_item = still_available_query.get() - db_item.available = False - db_item.processing_expires = now + timedelta(seconds=processing_time) - db_item.retries_remaining -= 1 - db_item.save() + # The previous solution to this used a select for update in a + # transaction to prevent multiple instances from processing the + # same queue item. This suffered performance problems. This solution + # instead has instances attempt to update the potential queue item to be + # unavailable. However, since their update clause is restricted to items + # that are available=False, only one instance's update will succeed, and + # it will have a changed row count of 1. Instances that have 0 changed + # rows know that another instance is already handling that item. + db_item = avail.order_by(QueueItem.id).get() + changed_query = (QueueItem.update( + available=False, + processing_expires=now + timedelta(seconds=processing_time), + retries_remaining=QueueItem.retries_remaining-1, + ) + .where(QueueItem.id == db_item.id)) + changed_query = self._available_jobs_where(changed_query, now, name_match_query) + changed = changed_query.execute() + if changed == 1: item = AttrDict({ 'id': db_item.id, 'body': db_item.body, - 'retries_remaining': db_item.retries_remaining + 'retries_remaining': db_item.retries_remaining - 1, }) - self._currently_processing = True except QueueItem.DoesNotExist: self._currently_processing = False diff --git a/local-test.sh b/local-test.sh index bab484c9c..edce17a61 100755 --- a/local-test.sh +++ b/local-test.sh @@ -5,3 +5,4 @@ export TROLLIUSDEBUG=1 python -m unittest discover -f python -m test.registry_tests -f +python -m test.queue_threads -f diff --git a/test/fulldbtest.sh b/test/fulldbtest.sh index 2e97ccd8b..68cb11660 100755 --- a/test/fulldbtest.sh +++ b/test/fulldbtest.sh @@ -36,6 +36,7 @@ down_postgres() { run_tests() { TEST_DATABASE_URI=$1 TEST=true python -m unittest discover -f + TEST_DATABASE_URI=$1 TEST=true python -m test.queue_threads -f } # NOTE: MySQL is currently broken on setup. diff --git a/test/queue_threads.py b/test/queue_threads.py new file mode 100644 index 000000000..0ed06109b --- /dev/null +++ b/test/queue_threads.py @@ -0,0 +1,77 @@ +import unittest +import json +import time + +from functools import wraps +from threading import Thread, Lock + +from app import app +from data.queue import WorkQueue +from initdb import wipe_database, initialize_database, populate_database + + +QUEUE_NAME = 'testqueuename' + +class AutoUpdatingQueue(object): + def __init__(self, queue_to_wrap): + self._queue = queue_to_wrap + + def _wrapper(self, func): + @wraps(func) + def wrapper(*args, **kwargs): + to_return = func(*args, **kwargs) + self._queue.update_metrics() + return to_return + return wrapper + + def __getattr__(self, attr_name): + method_or_attr = getattr(self._queue, attr_name) + if callable(method_or_attr): + return self._wrapper(method_or_attr) + else: + return method_or_attr + + +class QueueTestCase(unittest.TestCase): + TEST_MESSAGE_1 = json.dumps({'data': 1}) + + def setUp(self): + self.transaction_factory = app.config['DB_TRANSACTION_FACTORY'] + self.queue = AutoUpdatingQueue(WorkQueue(QUEUE_NAME, self.transaction_factory)) + wipe_database() + initialize_database() + populate_database() + + +class TestQueueThreads(QueueTestCase): + def test_queue_threads(self): + count = [20] + for i in range(count[0]): + self.queue.put([str(i)], self.TEST_MESSAGE_1) + + lock = Lock() + def get(lock, count, queue): + item = queue.get() + if item is None: + return + self.assertEqual(self.TEST_MESSAGE_1, item.body) + with lock: + count[0] -= 1 + + threads = [] + # The thread count needs to be a few times higher than the queue size + # count because some threads will get a None and thus won't decrement + # the counter. + for i in range(100): + t = Thread(target=get, args=(lock, count, self.queue)) + threads.append(t) + for t in threads: + t.start() + for t in threads: + t.join() + self.assertEqual(count[0], 0) + + +if __name__ == '__main__': + unittest.main() +