diff --git a/tests/conftest.py b/tests/conftest.py index 21a8039..8985470 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,13 +4,11 @@ import pytest from fastapi.testclient import TestClient from app.database import Base +from app.database import SessionLocal from app.database import async_engine from app.database import async_session from app.database import engine from app.main import app -from tests.factories import _Session - -# _Session = orm.sessionmaker(bind=engine, autocommit=False, autoflush=False) @pytest.fixture @@ -26,14 +24,14 @@ async def async_db_session(): @pytest.fixture def db() -> Generator: Base.metadata.create_all(bind=engine) - # sess = orm.sessionmaker(bind=engine)() - yield _Session - # yield orm.scoped_session(orm.sessionmaker(bind=engine)) try: - Base.metadata.drop_all(bind=engine) - except Exception: - # XXX: for some reason, the teardown occasionally fails because of this - pass + yield SessionLocal() + finally: + try: + Base.metadata.drop_all(bind=engine) + except Exception: + # XXX: for some reason, the teardown occasionally fails because of this + pass @pytest.fixture diff --git a/tests/factories.py b/tests/factories.py index f4a290f..a40bca0 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -11,10 +11,10 @@ from app import actor from app import models from app.actor import RemoteActor from app.ap_object import RemoteObject -from app.database import engine +from app.database import SessionLocal from app.database import now -_Session = orm.scoped_session(orm.sessionmaker(bind=engine)) +_Session = orm.scoped_session(SessionLocal) def generate_key() -> tuple[str, str]: