Tweak incoming/outgoing workers

This commit is contained in:
Thomas Sileo 2022-08-11 12:24:17 +02:00
parent 0b6556e54a
commit 0696268d0b
5 changed files with 30 additions and 46 deletions

View file

@ -12,30 +12,9 @@ T = TypeVar("T")
class Worker(Generic[T]):
def __init__(self, workers_count: int) -> None:
def __init__(self) -> None:
self._loop = asyncio.get_event_loop()
self._in_flight: set[int] = set()
self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1)
self._stop_event = asyncio.Event()
self._workers_count = workers_count
async def _consumer(self, db_session: AsyncSession) -> None:
while not self._stop_event.is_set():
message = await self._queue.get()
try:
await self.process_message(db_session, message)
finally:
self._in_flight.remove(message.id) # type: ignore
self._queue.task_done()
async def _producer(self, db_session: AsyncSession) -> None:
while not self._stop_event.is_set():
next_message = await self.get_next_message(db_session)
if next_message:
self._in_flight.add(next_message.id) # type: ignore
await self._queue.put(next_message)
else:
await asyncio.sleep(1)
async def process_message(self, db_session: AsyncSession, message: T) -> None:
raise NotImplementedError
@ -46,8 +25,16 @@ class Worker(Generic[T]):
async def startup(self, db_session: AsyncSession) -> None:
return None
def in_flight_ids(self) -> set[int]:
return self._in_flight
async def _main_loop(self, db_session: AsyncSession) -> None:
while not self._stop_event.is_set():
next_message = await self.get_next_message(db_session)
if next_message:
await self.process_message(db_session, next_message)
else:
await asyncio.sleep(1)
async def _until_stopped(self) -> None:
await self._stop_event.wait()
async def run_forever(self) -> None:
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
@ -59,13 +46,14 @@ class Worker(Generic[T]):
async with async_session() as db_session:
await self.startup(db_session)
self._loop.create_task(self._producer(db_session))
for _ in range(self._workers_count):
self._loop.create_task(self._consumer(db_session))
task = self._loop.create_task(self._main_loop(db_session))
stop_task = self._loop.create_task(self._until_stopped())
await self._stop_event.wait()
logger.info("Waiting for tasks to finish")
await self._queue.join()
done, pending = await asyncio.wait(
{task, stop_task}, return_when=asyncio.FIRST_COMPLETED
)
logger.info(f"Waiting for tasks to finish {done=}/{pending=}")
await asyncio.sleep(5)
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
logger.info(f"Cancelling {len(tasks)} tasks")
[task.cancel() for task in tasks]