import asyncio
import signal
from typing import Generic
from typing import TypeVar

from loguru import logger

from app.database import AsyncSession
from app.database import async_session

T = TypeVar("T")


class Worker(Generic[T]):
    """Base class for a polling worker that processes messages of type ``T``.

    Subclasses implement :meth:`get_next_message` and :meth:`process_message`
    (and optionally :meth:`startup`). :meth:`run_forever` opens one database
    session, polls for messages until SIGHUP/SIGTERM/SIGINT is received or the
    main loop ends, then cancels the remaining tasks on the event loop.
    """

    def __init__(self) -> None:
        # Bind the loop only if one is already running; otherwise defer to
        # run_forever(). asyncio.get_event_loop() is deprecated outside a
        # running loop and may return a loop other than the one that later
        # executes run_forever() (e.g. under asyncio.run()).
        self._loop: asyncio.AbstractEventLoop | None
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            self._loop = None
        self._stop_event = asyncio.Event()

    async def process_message(self, db_session: AsyncSession, message: T) -> None:
        """Handle a single message. Must be overridden by subclasses."""
        raise NotImplementedError

    async def get_next_message(self, db_session: AsyncSession) -> T | None:
        """Return the next message, or ``None`` when there is nothing to do.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

    async def startup(self, db_session: AsyncSession) -> None:
        """Hook awaited once before the main loop starts. No-op by default."""
        return None

    async def _main_loop(self, db_session: AsyncSession) -> None:
        """Poll for messages until the stop event is set.

        Sleeps 0.5s after processing a message and 2s when none is available.
        """
        while not self._stop_event.is_set():
            next_message = await self.get_next_message(db_session)
            # Compare against None explicitly so falsy-but-valid messages
            # (0, "", empty containers) are still processed.
            if next_message is not None:
                await self.process_message(db_session, next_message)
                await asyncio.sleep(0.5)
            else:
                await asyncio.sleep(2)

    async def _until_stopped(self) -> None:
        """Block until the stop event is set."""
        await self._stop_event.wait()

    async def run_forever(self) -> None:
        """Run the worker until a shutdown signal arrives or the main loop ends.

        Installs handlers for SIGHUP, SIGTERM and SIGINT that schedule
        :meth:`_shutdown`, then runs :meth:`startup` and :meth:`_main_loop`
        inside a single database session. When either the main loop or the
        stop waiter finishes, every other task on the event loop (except the
        current one) is cancelled and awaited for up to 15 seconds. The signal
        handlers are removed before returning.
        """
        # Always use the loop that is actually running this coroutine.
        loop = self._loop = asyncio.get_running_loop()

        signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
        for s in signals:
            loop.add_signal_handler(
                s,
                # Bind ``s`` as a default to avoid the late-binding closure trap.
                lambda s=s: asyncio.create_task(self._shutdown(s)),
            )

        try:
            async with async_session() as db_session:
                await self.startup(db_session)
                task = loop.create_task(self._main_loop(db_session))
                stop_task = loop.create_task(self._until_stopped())

                done, pending = await asyncio.wait(
                    {task, stop_task}, return_when=asyncio.FIRST_COMPLETED
                )
                logger.info(f"Waiting for tasks to finish {done=}/{pending=}")

                # Surface a crash of the main loop instead of letting it be
                # swallowed by gather(return_exceptions=True) below.
                if task in done and not task.cancelled():
                    error = task.exception()
                    if error is not None:
                        logger.opt(exception=error).error("Main loop crashed")

                current = asyncio.current_task()
                tasks = [t for t in asyncio.all_tasks() if t is not current]
                logger.info(f"Cancelling {len(tasks)} tasks")
                for t in tasks:
                    t.cancel()

                # Await the cancellations while the session is still open so
                # tasks unwinding through db_session do not see it closed.
                try:
                    await asyncio.wait_for(
                        asyncio.gather(*tasks, return_exceptions=True),
                        timeout=15,
                    )
                except asyncio.TimeoutError:
                    logger.info("Tasks failed to cancel")
        finally:
            # Do not leave handlers (which reference this worker) installed
            # on the loop after the worker has stopped.
            for s in signals:
                loop.remove_signal_handler(s)

        logger.info("stopping loop")

    async def _shutdown(self, sig: signal.Signals) -> None:
        """Log the received signal and request the worker to stop."""
        # Log the received signal, not the ``signal`` module.
        logger.info(f"Caught signal={sig!r}")
        self._stop_event.set()