67 lines
2.2 KiB
67 lines
2.2 KiB
import asyncio
import signal
from typing import Generic
from typing import TypeVar
from loguru import logger
from app.database import AsyncSession
from app.database import async_session
# Type variable for the messages a concrete ``Worker`` fetches and processes.
T = TypeVar("T")
class Worker(Generic[T]):
    """Base class for a long-running worker that polls for messages of type ``T``.

    Subclasses must override ``get_next_message`` and ``process_message``
    (both raise ``NotImplementedError`` here) and may override ``startup``.
    ``run_forever`` opens a single database session, passes it to every hook,
    and keeps polling until SIGHUP, SIGTERM or SIGINT is received.
    """

    def __init__(self) -> None:
        # NOTE(review): ``asyncio.get_event_loop()`` is deprecated when no loop
        # is running on recent Python versions, and the loop it returns is not
        # necessarily the one that later runs ``run_forever`` (e.g. under
        # ``asyncio.run``). It is kept so subclasses reading ``self._loop``
        # still work; ``run_forever`` rebinds it to the running loop.
        self._loop = asyncio.get_event_loop()
        # Set by ``_shutdown``; ``_main_loop`` checks it before each poll.
        self._stop_event = asyncio.Event()

    async def process_message(self, db_session: AsyncSession, message: T) -> None:
        """Handle one message returned by ``get_next_message``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def get_next_message(self, db_session: AsyncSession) -> T | None:
        """Return the next message to process.

        Return a falsy value (typically ``None``) when there is nothing to do;
        the worker then waits one second before polling again.
        Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def startup(self, db_session: AsyncSession) -> None:
        """Hook awaited once before polling starts; the default does nothing."""
        return None

    async def _main_loop(self, db_session: AsyncSession) -> None:
        """Fetch and process messages until the stop event is set."""
        while not self._stop_event.is_set():
            next_message = await self.get_next_message(db_session)
            # Truthiness (rather than ``is not None``) is kept on purpose:
            # subclasses may signal "nothing to do" with any falsy value,
            # e.g. an empty batch.
            if next_message:
                await self.process_message(db_session, next_message)
            else:
                # Back off only when idle so a backlog is drained without an
                # extra one-second pause after every message.
                await asyncio.sleep(1)

    async def _until_stopped(self) -> None:
        """Return once the stop event has been set."""
        await self._stop_event.wait()

    async def run_forever(self) -> None:
        """Run the worker until SIGHUP, SIGTERM or SIGINT is received.

        Installs signal handlers on the running loop, then awaits ``startup``
        and runs the polling loop with one shared database session. When a
        stop signal arrives, or the polling loop ends on its own (e.g. by
        raising), it waits 5 seconds, then cancels and awaits every other task
        on the loop, all before the session is closed. An exception raised by
        the polling loop is logged instead of being silently dropped.
        """
        # Tasks and signal handlers must live on the loop actually executing
        # this coroutine; the loop captured in ``__init__`` may be a different,
        # never-running one, where scheduled work would never execute.
        self._loop = asyncio.get_running_loop()

        # The event loop keeps only weak references to tasks, so hold each
        # shutdown task here until it completes.
        shutdown_tasks: set[asyncio.Task[None]] = set()

        def on_signal(sig: signal.Signals) -> None:
            shutdown_task = self._loop.create_task(self._shutdown(sig))
            shutdown_tasks.add(shutdown_task)
            shutdown_task.add_done_callback(shutdown_tasks.discard)

        for s in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
            # Passing the signal as a handler argument avoids late binding of
            # the loop variable inside a closure.
            self._loop.add_signal_handler(s, on_signal, s)

        async with async_session() as db_session:
            await self.startup(db_session)
            main_task = self._loop.create_task(self._main_loop(db_session))
            stop_task = self._loop.create_task(self._until_stopped())

            done, pending = await asyncio.wait(
                {main_task, stop_task}, return_when=asyncio.FIRST_COMPLETED
            )
            logger.info(f"Waiting for tasks to finish {done=}/{pending=}")
            # Grace period for in-flight work before everything is cancelled.
            await asyncio.sleep(5)

            # Cancel while the session is still open so no task (in particular
            # the polling loop) can touch it after it has been closed.
            current = asyncio.current_task()
            tasks = [t for t in asyncio.all_tasks() if t is not current]
            logger.info(f"Cancelling {len(tasks)} tasks")
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)

            # ``all_tasks`` skips finished tasks, so a polling loop that already
            # crashed must have its exception retrieved explicitly.
            if main_task.done() and not main_task.cancelled():
                main_error = main_task.exception()
                if main_error is not None:
                    logger.opt(exception=main_error).error("Worker main loop failed")

        logger.info("stopping loop")

    async def _shutdown(self, sig: signal.Signals) -> None:
        """Log the received signal and ask the worker to stop."""
        # Log the received signal, not the ``signal`` module.
        logger.info(f"Caught {sig=}")
        self._stop_event.set()