diff --git a/app/utils/workers.py b/app/utils/workers.py new file mode 100644 index 0000000..a817834 --- /dev/null +++ b/app/utils/workers.py @@ -0,0 +1,79 @@ +import asyncio +import signal +from typing import Generic +from typing import TypeVar + +from loguru import logger + +from app.database import AsyncSession +from app.database import async_session + +T = TypeVar("T") + + +class Worker(Generic[T]): + def __init__(self, workers_count: int) -> None: + self._loop = asyncio.get_event_loop() + self._in_flight: set[int] = set() + self._queue: asyncio.Queue[T] = asyncio.Queue(maxsize=1) + self._stop_event = asyncio.Event() + self._workers_count = workers_count + + async def _consumer(self, db_session: AsyncSession) -> None: + while not self._stop_event.is_set(): + message = await self._queue.get() + try: + await self.process_message(db_session, message) + finally: + self._in_flight.remove(message.id) # type: ignore + self._queue.task_done() + + async def _producer(self, db_session: AsyncSession) -> None: + while not self._stop_event.is_set(): + next_message = await self.get_next_message(db_session) + if next_message: + self._in_flight.add(next_message.id) # type: ignore + await self._queue.put(next_message) + else: + await asyncio.sleep(1) + + async def process_message(self, db_session: AsyncSession, message: T) -> None: + raise NotImplementedError + + async def get_next_message(self, db_session: AsyncSession) -> T | None: + raise NotImplementedError + + async def startup(self, db_session: AsyncSession) -> None: + return None + + def in_flight_ids(self) -> set[int]: + return self._in_flight + + async def run_forever(self) -> None: + signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) + for s in signals: + self._loop.add_signal_handler( + s, + lambda s=s: asyncio.create_task(self._shutdown(s)), + ) + + async with async_session() as db_session: + await self.startup(db_session) + self._loop.create_task(self._producer(db_session)) + for _ in range(self._workers_count): + self._loop.create_task(self._consumer(db_session)) + + await self._stop_event.wait() + logger.info("Waiting for tasks to finish") + await self._queue.join() + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] + logger.info(f"Cancelling {len(tasks)} tasks") + [task.cancel() for task in tasks] + + await asyncio.gather(*tasks, return_exceptions=True) + + logger.info("stopping loop") + + async def _shutdown(self, sig: signal.Signals) -> None: + logger.info(f"Caught {signal=}") + self._stop_event.set()