From 1f54a6a6ac1b3f5c7dbc30732df28e66717214fa Mon Sep 17 00:00:00 2001 From: Thomas Sileo Date: Wed, 29 Jun 2022 20:43:17 +0200 Subject: [PATCH] Switch to aiosqlite --- app/actor.py | 40 +-- app/admin.py | 285 +++++++++--------- app/boxes.py | 296 +++++++++++-------- app/config.py | 6 +- app/database.py | 19 +- app/lookup.py | 6 +- app/main.py | 342 ++++++++++++---------- app/outgoing_activities.py | 20 +- app/source.py | 26 +- app/templates.py | 16 +- app/uploads.py | 14 +- data/tests.toml | 4 +- poetry.lock | 103 ++++--- pyproject.toml | 4 +- tests/conftest.py | 19 +- tests/test_actor.py | 21 +- tests/test_emoji.py | 2 +- tests/test_inbox.py | 2 +- tests/test_outbox.py | 2 +- tests/test_process_outgoing_activities.py | 18 +- tests/test_public.py | 2 +- 21 files changed, 698 insertions(+), 549 deletions(-) diff --git a/app/actor.py b/app/actor.py index b906e2b..69789b6 100644 --- a/app/actor.py +++ b/app/actor.py @@ -4,11 +4,11 @@ from typing import Union from urllib.parse import urlparse from sqlalchemy import select -from sqlalchemy.orm import Session from sqlalchemy.orm import joinedload from app import activitypub as ap from app import media +from app.database import AsyncSession if typing.TYPE_CHECKING: from app.models import Actor as ActorModel @@ -131,7 +131,7 @@ class RemoteActor(Actor): LOCAL_ACTOR = RemoteActor(ap_actor=ap.ME) -def save_actor(db: Session, ap_actor: ap.RawObject) -> "ActorModel": +async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "ActorModel": from app import models if ap_type := ap_actor.get("type") not in ap.ACTOR_TYPES: @@ -143,23 +143,25 @@ def save_actor(db: Session, ap_actor: ap.RawObject) -> "ActorModel": ap_type=ap_actor["type"], handle=_handle(ap_actor), ) - db.add(actor) - db.commit() - db.refresh(actor) + db_session.add(actor) + await db_session.commit() + await db_session.refresh(actor) return actor -def fetch_actor(db: Session, actor_id: str) -> "ActorModel": +async def fetch_actor(db_session: AsyncSession, actor_id: str) -> "ActorModel": from app import models - existing_actor = db.execute( - select(models.Actor).where(models.Actor.ap_id == actor_id) - ).scalar_one_or_none() + existing_actor = ( + await db_session.scalars( + select(models.Actor).where(models.Actor.ap_id == actor_id) + ) + ).one_or_none() if existing_actor: return existing_actor ap_actor = ap.get(actor_id) - return save_actor(db, ap_actor) + return await save_actor(db_session, ap_actor) @dataclass @@ -175,8 +177,8 @@ class ActorMetadata: ActorsMetadata = dict[str, ActorMetadata] -def get_actors_metadata( - db: Session, +async def get_actors_metadata( + db_session: AsyncSession, actors: list[Union["ActorModel", "RemoteActor"]], ) -> ActorsMetadata: from app import models @@ -184,17 +186,19 @@ def get_actors_metadata( ap_actor_ids = [actor.ap_id for actor in actors] followers = { follower.ap_actor_id: follower.inbox_object.ap_id - for follower in db.scalars( - select(models.Follower) - .where(models.Follower.ap_actor_id.in_(ap_actor_ids)) - .options(joinedload(models.Follower.inbox_object)) + for follower in ( + await db_session.scalars( + select(models.Follower) + .where(models.Follower.ap_actor_id.in_(ap_actor_ids)) + .options(joinedload(models.Follower.inbox_object)) + ) ) .unique() .all() } following = { following.ap_actor_id - for following in db.execute( + for following in await db_session.execute( select(models.Following.ap_actor_id).where( models.Following.ap_actor_id.in_(ap_actor_ids) ) @@ -202,7 +206,7 @@ def get_actors_metadata( } sent_follow_requests = { follow_req.ap_object["object"]: follow_req.ap_id - for follow_req in db.execute( + for follow_req in await db_session.execute( select(models.OutboxObject.ap_object, models.OutboxObject.ap_id).where( models.OutboxObject.ap_type == "Follow", models.OutboxObject.undone_by_outbox_object_id.is_(None), diff --git a/app/admin.py b/app/admin.py index 385c123..49fa650 100644 --- a/app/admin.py +++ b/app/admin.py @@ -8,7 +8,6 @@ from fastapi.exceptions import HTTPException from fastapi.responses import RedirectResponse from sqlalchemy import func from sqlalchemy import select -from sqlalchemy.orm import Session from sqlalchemy.orm import joinedload from app import activitypub as ap @@ -25,7 +24,8 @@ from app.config import generate_csrf_token from app.config import session_serializer from app.config import verify_csrf_token from app.config import verify_password -from app.database import get_db +from app.database import AsyncSession +from app.database import get_db_session from app.lookup import lookup from app.uploads import save_upload from app.utils import pagination @@ -62,30 +62,36 @@ unauthenticated_router = APIRouter() @router.get("/") -def admin_index( +async def admin_index( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: - return templates.render_template(db, request, "index.html", {"request": request}) + return await templates.render_template( + db_session, request, "index.html", {"request": request} + ) @router.get("/lookup") -def get_lookup( +async def get_lookup( request: Request, query: str | None = None, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: ap_object = None actors_metadata = {} if query: - ap_object = lookup(db, query) + ap_object = await lookup(db_session, query) if ap_object.ap_type in ap.ACTOR_TYPES: - actors_metadata = get_actors_metadata(db, [ap_object]) # type: ignore + actors_metadata = await get_actors_metadata( + db_session, [ap_object] # type: ignore + ) else: - actors_metadata = get_actors_metadata(db, [ap_object.actor]) # type: ignore + actors_metadata = await get_actors_metadata( + db_session, [ap_object.actor] # type: ignore + ) print(ap_object) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "lookup.html", { @@ -97,16 +103,18 @@ def get_lookup( @router.get("/new") -def admin_new( +async def admin_new( request: Request, query: str | None = None, in_reply_to: str | None = None, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: content = "" in_reply_to_object = None if in_reply_to: - in_reply_to_object = boxes.get_anybox_object_by_ap_id(db, in_reply_to) + in_reply_to_object = await boxes.get_anybox_object_by_ap_id( + db_session, in_reply_to + ) # Add mentions to the initial note content if not in_reply_to_object: @@ -117,8 +125,8 @@ def admin_new( if tag.get("type") == "Mention" and tag["name"] != LOCAL_ACTOR.handle: content += f'{tag["name"]} ' - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "admin_new.html", { @@ -138,28 +146,30 @@ def admin_new( @router.get("/bookmarks") -def admin_bookmarks( +async def admin_bookmarks( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: stream = ( - db.scalars( - select(models.InboxObject) - .where( - models.InboxObject.ap_type.in_( - ["Note", "Article", "Video", "Announce"] - ), - models.InboxObject.is_hidden_from_stream.is_(False), - models.InboxObject.undone_by_inbox_object_id.is_(None), - models.InboxObject.is_bookmarked.is_(True), + ( + await db_session.scalars( + select(models.InboxObject) + .where( + models.InboxObject.ap_type.in_( + ["Note", "Article", "Video", "Announce"] + ), + models.InboxObject.is_hidden_from_stream.is_(False), + models.InboxObject.undone_by_inbox_object_id.is_(None), + models.InboxObject.is_bookmarked.is_(True), + ) + .order_by(models.InboxObject.ap_published_at.desc()) + .limit(20) ) - .order_by(models.InboxObject.ap_published_at.desc()) - .limit(20) ).all() # TODO: joinedload + unique ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "admin_stream.html", { @@ -169,9 +179,9 @@ def admin_bookmarks( @router.get("/inbox") -def admin_inbox( +async def admin_inbox( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), filter_by: str | None = None, cursor: str | None = None, ) -> templates.TemplateResponse: @@ -184,17 +194,22 @@ def admin_inbox( ) page_size = 20 - remaining_count = db.scalar(select(func.count(models.InboxObject.id)).where(*where)) + remaining_count = await db_session.scalar( + select(func.count(models.InboxObject.id)).where(*where) + ) q = select(models.InboxObject).where(*where) inbox = ( - db.scalars( - q.options( - joinedload(models.InboxObject.relates_to_inbox_object), - joinedload(models.InboxObject.relates_to_outbox_object), + ( + await db_session.scalars( + q.options( + joinedload(models.InboxObject.relates_to_inbox_object), + joinedload(models.InboxObject.relates_to_outbox_object), + joinedload(models.InboxObject.actor), + ) + .order_by(models.InboxObject.ap_published_at.desc()) + .limit(20) ) - .order_by(models.InboxObject.ap_published_at.desc()) - .limit(20) ) .unique() .all() @@ -206,8 +221,8 @@ def admin_inbox( else None ) - actors_metadata = get_actors_metadata( - db, + actors_metadata = await get_actors_metadata( + db_session, [ inbox_object.actor for inbox_object in inbox @@ -215,8 +230,8 @@ def admin_inbox( ], ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "admin_inbox.html", { @@ -228,9 +243,9 @@ def admin_inbox( @router.get("/outbox") -def admin_outbox( +async def admin_outbox( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), filter_by: str | None = None, cursor: str | None = None, ) -> templates.TemplateResponse: @@ -243,20 +258,22 @@ def admin_outbox( ) page_size = 20 - remaining_count = db.scalar( + remaining_count = await db_session.scalar( select(func.count(models.OutboxObject.id)).where(*where) ) q = select(models.OutboxObject).where(*where) outbox = ( - db.scalars( - q.options( - joinedload(models.OutboxObject.relates_to_inbox_object), - joinedload(models.OutboxObject.relates_to_outbox_object), - joinedload(models.OutboxObject.relates_to_actor), + ( + await db_session.scalars( + q.options( + joinedload(models.OutboxObject.relates_to_inbox_object), + joinedload(models.OutboxObject.relates_to_outbox_object), + joinedload(models.OutboxObject.relates_to_actor), + ) + .order_by(models.OutboxObject.ap_published_at.desc()) + .limit(page_size) ) - .order_by(models.OutboxObject.ap_published_at.desc()) - .limit(page_size) ) .unique() .all() @@ -268,8 +285,8 @@ def admin_outbox( else None ) - actors_metadata = get_actors_metadata( - db, + actors_metadata = await get_actors_metadata( + db_session, [ outbox_object.relates_to_actor for outbox_object in outbox @@ -277,8 +294,8 @@ def admin_outbox( ], ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "admin_outbox.html", { @@ -290,32 +307,34 @@ def admin_outbox( @router.get("/notifications") -def get_notifications( - request: Request, db: Session = Depends(get_db) +async def get_notifications( + request: Request, db_session: AsyncSession = Depends(get_db_session) ) -> templates.TemplateResponse: notifications = ( - db.scalars( - select(models.Notification) - .options( - joinedload(models.Notification.actor), - joinedload(models.Notification.inbox_object), - joinedload(models.Notification.outbox_object), + ( + await db_session.scalars( + select(models.Notification) + .options( + joinedload(models.Notification.actor), + joinedload(models.Notification.inbox_object), + joinedload(models.Notification.outbox_object), + ) + .order_by(models.Notification.created_at.desc()) ) - .order_by(models.Notification.created_at.desc()) ) .unique() .all() ) - actors_metadata = get_actors_metadata( - db, [notif.actor for notif in notifications if notif.actor] + actors_metadata = await get_actors_metadata( + db_session, [notif.actor for notif in notifications if notif.actor] ) for notif in notifications: notif.is_new = False - db.commit() + await db_session.commit() - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "notifications.html", { @@ -326,19 +345,19 @@ def get_notifications( @router.get("/object") -def admin_object( +async def admin_object( request: Request, ap_id: str, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: - requested_object = boxes.get_anybox_object_by_ap_id(db, ap_id) + requested_object = await boxes.get_anybox_object_by_ap_id(db_session, ap_id) if not requested_object: raise HTTPException(status_code=404) - replies_tree = boxes.get_replies_tree(db, requested_object) + replies_tree = await boxes.get_replies_tree(db_session, requested_object) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "object.html", {"replies_tree": replies_tree}, @@ -346,30 +365,34 @@ def admin_object( @router.get("/profile") -def admin_profile( +async def admin_profile( request: Request, actor_id: str, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: - actor = db.execute( - select(models.Actor).where(models.Actor.ap_id == actor_id) + actor = ( + await db_session.execute( + select(models.Actor).where(models.Actor.ap_id == actor_id) + ) ).scalar_one_or_none() if not actor: raise HTTPException(status_code=404) - actors_metadata = get_actors_metadata(db, [actor]) + actors_metadata = await get_actors_metadata(db_session, [actor]) - inbox_objects = db.scalars( - select(models.InboxObject) - .where( - models.InboxObject.actor_id == actor.id, - models.InboxObject.ap_type.in_(["Note", "Article", "Video"]), + inbox_objects = ( + await db_session.scalars( + select(models.InboxObject) + .where( + models.InboxObject.actor_id == actor.id, + models.InboxObject.ap_type.in_(["Note", "Article", "Video"]), + ) + .order_by(models.InboxObject.ap_published_at.desc()) ) - .order_by(models.InboxObject.ap_published_at.desc()) ).all() - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "admin_profile.html", { @@ -381,120 +404,120 @@ def admin_profile( @router.post("/actions/follow") -def admin_actions_follow( +async def admin_actions_follow( request: Request, ap_actor_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: print(f"Following {ap_actor_id}") - send_follow(db, ap_actor_id) + await send_follow(db_session, ap_actor_id) return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/like") -def admin_actions_like( +async def admin_actions_like( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - boxes.send_like(db, ap_object_id) + await boxes.send_like(db_session, ap_object_id) return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/undo") -def admin_actions_undo( +async def admin_actions_undo( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - boxes.send_undo(db, ap_object_id) + await boxes.send_undo(db_session, ap_object_id) return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/announce") -def admin_actions_announce( +async def admin_actions_announce( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - boxes.send_announce(db, ap_object_id) + await boxes.send_announce(db_session, ap_object_id) return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/bookmark") -def admin_actions_bookmark( +async def admin_actions_bookmark( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - inbox_object = get_inbox_object_by_ap_id(db, ap_object_id) + inbox_object = await get_inbox_object_by_ap_id(db_session, ap_object_id) if not inbox_object: raise ValueError("Should never happen") inbox_object.is_bookmarked = True - db.commit() + await db_session.commit() return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/unbookmark") -def admin_actions_unbookmark( +async def admin_actions_unbookmark( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - inbox_object = get_inbox_object_by_ap_id(db, ap_object_id) + inbox_object = await get_inbox_object_by_ap_id(db_session, ap_object_id) if not inbox_object: raise ValueError("Should never happen") inbox_object.is_bookmarked = False - db.commit() + await db_session.commit() return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/pin") -def admin_actions_pin( +async def admin_actions_pin( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - outbox_object = get_outbox_object_by_ap_id(db, ap_object_id) + outbox_object = await get_outbox_object_by_ap_id(db_session, ap_object_id) if not outbox_object: raise ValueError("Should never happen") outbox_object.is_pinned = True - db.commit() + await db_session.commit() return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/unpin") -def admin_actions_unpin( +async def admin_actions_unpin( request: Request, ap_object_id: str = Form(), redirect_url: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: - outbox_object = get_outbox_object_by_ap_id(db, ap_object_id) + outbox_object = await get_outbox_object_by_ap_id(db_session, ap_object_id) if not outbox_object: raise ValueError("Should never happen") outbox_object.is_pinned = False - db.commit() + await db_session.commit() return RedirectResponse(redirect_url, status_code=302) @router.post("/actions/new") -def admin_actions_new( +async def admin_actions_new( request: Request, files: list[UploadFile] = [], content: str = Form(), @@ -504,16 +527,16 @@ def admin_actions_new( is_sensitive: bool = Form(False), visibility: str = Form(), csrf_check: None = Depends(verify_csrf_token), - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> RedirectResponse: # XXX: for some reason, no files restuls in an empty single file uploads = [] if len(files) >= 1 and files[0].filename: for f in files: - upload = save_upload(db, f) + upload = await save_upload(db_session, f) uploads.append((upload, f.filename)) - public_id = boxes.send_create( - db, + public_id = await boxes.send_create( + db_session, source=content, uploads=uploads, in_reply_to=in_reply_to or None, @@ -528,12 +551,12 @@ def admin_actions_new( @unauthenticated_router.get("/login") -def login( +async def login( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "login.html", {"csrf_token": generate_csrf_token()}, @@ -541,7 +564,7 @@ def login( @unauthenticated_router.post("/login") -def login_validation( +async def login_validation( request: Request, password: str = Form(), csrf_check: None = Depends(verify_csrf_token), @@ -556,7 +579,7 @@ def login_validation( @router.get("/logout") -def logout( +async def logout( request: Request, ) -> RedirectResponse: resp = RedirectResponse(request.url_for("index"), status_code=302) diff --git a/app/boxes.py b/app/boxes.py index e8590d7..f68f5cb 100644 --- a/app/boxes.py +++ b/app/boxes.py @@ -12,7 +12,6 @@ from sqlalchemy import func from sqlalchemy import select from sqlalchemy import update from sqlalchemy.exc import IntegrityError -from sqlalchemy.orm import Session from sqlalchemy.orm import joinedload from app import activitypub as ap @@ -26,6 +25,7 @@ from app.actor import save_actor from app.ap_object import RemoteObject from app.config import BASE_URL from app.config import ID +from app.database import AsyncSession from app.database import now from app.outgoing_activities import new_outgoing_activity from app.source import markdownify @@ -42,8 +42,8 @@ def outbox_object_id(outbox_id) -> str: return f"{BASE_URL}/o/{outbox_id}" -def save_outbox_object( - db: Session, +async def save_outbox_object( + db_session: AsyncSession, public_id: str, raw_object: ap.RawObject, relates_to_inbox_object_id: int | None = None, @@ -68,15 +68,15 @@ def save_outbox_object( is_hidden_from_homepage=True if ra.in_reply_to else False, source=source, ) - db.add(outbox_object) - db.commit() - db.refresh(outbox_object) + db_session.add(outbox_object) + await db_session.commit() + await db_session.refresh(outbox_object) return outbox_object -def send_like(db: Session, ap_object_id: str) -> None: - inbox_object = get_inbox_object_by_ap_id(db, ap_object_id) +async def send_like(db_session: AsyncSession, ap_object_id: str) -> None: + inbox_object = await get_inbox_object_by_ap_id(db_session, ap_object_id) if not inbox_object: raise ValueError(f"{ap_object_id} not found in the inbox") @@ -88,20 +88,22 @@ def send_like(db: Session, ap_object_id: str) -> None: "actor": ID, "object": ap_object_id, } - outbox_object = save_outbox_object( - db, like_id, like, relates_to_inbox_object_id=inbox_object.id + outbox_object = await save_outbox_object( + db_session, like_id, like, relates_to_inbox_object_id=inbox_object.id ) if not outbox_object.id: raise ValueError("Should never happen") inbox_object.liked_via_outbox_object_ap_id = outbox_object.ap_id - db.commit() + await db_session.commit() - new_outgoing_activity(db, inbox_object.actor.inbox_url, outbox_object.id) + await new_outgoing_activity( + db_session, inbox_object.actor.inbox_url, outbox_object.id + ) -def send_announce(db: Session, ap_object_id: str) -> None: - inbox_object = get_inbox_object_by_ap_id(db, ap_object_id) +async def send_announce(db_session: AsyncSession, ap_object_id: str) -> None: + inbox_object = await get_inbox_object_by_ap_id(db_session, ap_object_id) if not inbox_object: raise ValueError(f"{ap_object_id} not found in the inbox") @@ -118,22 +120,22 @@ def send_announce(db: Session, ap_object_id: str) -> None: inbox_object.ap_actor_id, ], } - outbox_object = save_outbox_object( - db, announce_id, announce, relates_to_inbox_object_id=inbox_object.id + outbox_object = await save_outbox_object( + db_session, announce_id, announce, relates_to_inbox_object_id=inbox_object.id ) if not outbox_object.id: raise ValueError("Should never happen") inbox_object.announced_via_outbox_object_ap_id = outbox_object.ap_id - db.commit() + await db_session.commit() - recipients = _compute_recipients(db, announce) + recipients = await _compute_recipients(db_session, announce) for rcp in recipients: - new_outgoing_activity(db, rcp, outbox_object.id) + await new_outgoing_activity(db_session, rcp, outbox_object.id) -def send_follow(db: Session, ap_actor_id: str) -> None: - actor = fetch_actor(db, ap_actor_id) +async def send_follow(db_session: AsyncSession, ap_actor_id: str) -> None: + actor = await fetch_actor(db_session, ap_actor_id) follow_id = allocate_outbox_id() follow = { @@ -144,17 +146,17 @@ def send_follow(db: Session, ap_actor_id: str) -> None: "object": ap_actor_id, } - outbox_object = save_outbox_object( - db, follow_id, follow, relates_to_actor_id=actor.id + outbox_object = await save_outbox_object( + db_session, follow_id, follow, relates_to_actor_id=actor.id ) if not outbox_object.id: raise ValueError("Should never happen") - new_outgoing_activity(db, actor.inbox_url, outbox_object.id) + await new_outgoing_activity(db_session, actor.inbox_url, outbox_object.id) -def send_undo(db: Session, ap_object_id: str) -> None: - outbox_object_to_undo = get_outbox_object_by_ap_id(db, ap_object_id) +async def send_undo(db_session: AsyncSession, ap_object_id: str) -> None: + outbox_object_to_undo = await get_outbox_object_by_ap_id(db_session, ap_object_id) if not outbox_object_to_undo: raise ValueError(f"{ap_object_id} not found in the outbox") @@ -172,8 +174,8 @@ def send_undo(db: Session, ap_object_id: str) -> None: "object": ap.remove_context(outbox_object_to_undo.ap_object), } - outbox_object = save_outbox_object( - db, + outbox_object = await save_outbox_object( + db_session, undo_id, undo, relates_to_outbox_object_id=outbox_object_to_undo.id, @@ -186,31 +188,33 @@ def send_undo(db: Session, ap_object_id: str) -> None: if outbox_object_to_undo.ap_type == "Follow": if not outbox_object_to_undo.activity_object_ap_id: raise ValueError("Should never happen") - followed_actor = fetch_actor(db, outbox_object_to_undo.activity_object_ap_id) - new_outgoing_activity( - db, + followed_actor = await fetch_actor( + db_session, outbox_object_to_undo.activity_object_ap_id + ) + await new_outgoing_activity( + db_session, followed_actor.inbox_url, outbox_object.id, ) # Also remove the follow from the following collection - db.execute( + await db_session.execute( delete(models.Following).where( models.Following.ap_actor_id == followed_actor.ap_id ) ) - db.commit() + await db_session.commit() elif outbox_object_to_undo.ap_type == "Like": liked_object_ap_id = outbox_object_to_undo.activity_object_ap_id if not liked_object_ap_id: raise ValueError("Should never happen") - liked_object = get_inbox_object_by_ap_id(db, liked_object_ap_id) + liked_object = await get_inbox_object_by_ap_id(db_session, liked_object_ap_id) if not liked_object: raise ValueError(f"Cannot find liked object {liked_object_ap_id}") liked_object.liked_via_outbox_object_ap_id = None # Send the Undo to the liked object's actor - new_outgoing_activity( - db, + await new_outgoing_activity( + db_session, liked_object.actor.inbox_url, # type: ignore outbox_object.id, ) @@ -218,21 +222,23 @@ def send_undo(db: Session, ap_object_id: str) -> None: announced_object_ap_id = outbox_object_to_undo.activity_object_ap_id if not announced_object_ap_id: raise ValueError("Should never happen") - announced_object = get_inbox_object_by_ap_id(db, announced_object_ap_id) + announced_object = await get_inbox_object_by_ap_id( + db_session, announced_object_ap_id + ) if not announced_object: raise ValueError(f"Cannot find announced object {announced_object_ap_id}") announced_object.announced_via_outbox_object_ap_id = None # Send the Undo to the original recipients - recipients = _compute_recipients(db, outbox_object.ap_object) + recipients = await _compute_recipients(db_session, outbox_object.ap_object) for rcp in recipients: - new_outgoing_activity(db, rcp, outbox_object.id) + await new_outgoing_activity(db_session, rcp, outbox_object.id) else: raise ValueError("Should never happen") -def send_create( - db: Session, +async def send_create( + db_session: AsyncSession, source: str, uploads: list[tuple[models.Upload, str]], in_reply_to: str | None, @@ -243,11 +249,11 @@ def send_create( note_id = allocate_outbox_id() published = now().replace(microsecond=0).isoformat().replace("+00:00", "Z") context = f"{ID}/contexts/" + uuid.uuid4().hex - content, tags, mentioned_actors = markdownify(db, source) + content, tags, mentioned_actors = await markdownify(db_session, source) attachments = [] if in_reply_to: - in_reply_to_object = get_anybox_object_by_ap_id(db, in_reply_to) + in_reply_to_object = await get_anybox_object_by_ap_id(db_session, in_reply_to) if not in_reply_to_object: raise ValueError(f"Invalid in reply to {in_reply_to=}") if not in_reply_to_object.ap_context: @@ -255,7 +261,7 @@ def send_create( context = in_reply_to_object.ap_context if in_reply_to_object.is_from_outbox: - db.execute( + await db_session.execute( update(models.OutboxObject) .where( models.OutboxObject.ap_id == in_reply_to, @@ -302,7 +308,7 @@ def send_create( "sensitive": is_sensitive, "attachment": attachments, } - outbox_object = save_outbox_object(db, note_id, note, source=source) + outbox_object = await save_outbox_object(db_session, note_id, note, source=source) if not outbox_object.id: raise ValueError("Should never happen") @@ -312,24 +318,26 @@ def send_create( tag=tag["name"][1:], outbox_object_id=outbox_object.id, ) - db.add(tagged_object) + db_session.add(tagged_object) for (upload, filename) in uploads: outbox_object_attachment = models.OutboxObjectAttachment( filename=filename, outbox_object_id=outbox_object.id, upload_id=upload.id ) - db.add(outbox_object_attachment) + db_session.add(outbox_object_attachment) - db.commit() + await db_session.commit() - recipients = _compute_recipients(db, note) + recipients = await _compute_recipients(db_session, note) for rcp in recipients: - new_outgoing_activity(db, rcp, outbox_object.id) + await new_outgoing_activity(db_session, rcp, outbox_object.id) return note_id -def _compute_recipients(db: Session, ap_object: ap.RawObject) -> set[str]: +async def _compute_recipients( + db_session: AsyncSession, ap_object: ap.RawObject +) -> set[str]: _recipients = [] for field in ["to", "cc", "bto", "bcc"]: if field in ap_object: @@ -343,15 +351,17 @@ def _compute_recipients(db: Session, ap_object: ap.RawObject) -> set[str]: # If we got a local collection, assume it's a collection of actors if r.startswith(BASE_URL): - for actor in fetch_actor_collection(db, r): + for actor in await fetch_actor_collection(db_session, r): recipients.add(actor.shared_inbox_url or actor.inbox_url) continue # Is it a known actor? - known_actor = db.execute( - select(models.Actor).where(models.Actor.ap_id == r) - ).scalar_one_or_none() + known_actor = ( + await db_session.execute( + select(models.Actor).where(models.Actor.ap_id == r) + ) + ).scalar_one_or_none() # type: ignore if known_actor: recipients.add(known_actor.shared_inbox_url or known_actor.inbox_url) continue @@ -359,7 +369,7 @@ def _compute_recipients(db: Session, ap_object: ap.RawObject) -> set[str]: # Fetch the object raw_object = ap.fetch(r) if raw_object.get("type") in ap.ACTOR_TYPES: - saved_actor = save_actor(db, raw_object) + saved_actor = await save_actor(db_session, raw_object) recipients.add(saved_actor.shared_inbox_url or saved_actor.inbox_url) else: # Assume it's a collection of actors @@ -370,27 +380,43 @@ def _compute_recipients(db: Session, ap_object: ap.RawObject) -> set[str]: return recipients -def get_inbox_object_by_ap_id(db: Session, ap_id: str) -> models.InboxObject | None: - return db.execute( - select(models.InboxObject).where(models.InboxObject.ap_id == ap_id) - ).scalar_one_or_none() +async def get_inbox_object_by_ap_id( + db_session: AsyncSession, ap_id: str +) -> models.InboxObject | None: + return ( + await db_session.execute( + select(models.InboxObject) + .where(models.InboxObject.ap_id == ap_id) + .options( + joinedload(models.InboxObject.actor), + joinedload(models.InboxObject.relates_to_inbox_object), + joinedload(models.InboxObject.relates_to_outbox_object), + ) + ) + ).scalar_one_or_none() # type: ignore -def get_outbox_object_by_ap_id(db: Session, ap_id: str) -> models.OutboxObject | None: - return db.execute( - select(models.OutboxObject).where(models.OutboxObject.ap_id == ap_id) - ).scalar_one_or_none() +async def get_outbox_object_by_ap_id( + db_session: AsyncSession, ap_id: str +) -> models.OutboxObject | None: + return ( + await db_session.execute( + select(models.OutboxObject).where(models.OutboxObject.ap_id == ap_id) + ) + ).scalar_one_or_none() # type: ignore -def get_anybox_object_by_ap_id(db: Session, ap_id: str) -> AnyboxObject | None: +async def get_anybox_object_by_ap_id( + db_session: AsyncSession, ap_id: str +) -> AnyboxObject | None: if ap_id.startswith(BASE_URL): - return get_outbox_object_by_ap_id(db, ap_id) + return await get_outbox_object_by_ap_id(db_session, ap_id) else: - return get_inbox_object_by_ap_id(db, ap_id) + return await get_inbox_object_by_ap_id(db_session, ap_id) -def _handle_delete_activity( - db: Session, +async def _handle_delete_activity( + db_session: AsyncSession, from_actor: models.Actor, ap_object_to_delete: models.InboxObject, ) -> None: @@ -404,12 +430,12 @@ def _handle_delete_activity( # TODO(ts): do we need to delete related activities? should we keep # bookmarked objects with a deleted flag? logger.info(f"Deleting {ap_object_to_delete.ap_type}/{ap_object_to_delete.ap_id}") - db.delete(ap_object_to_delete) - db.flush() + await db_session.delete(ap_object_to_delete) + await db_session.flush() -def _handle_follow_follow_activity( - db: Session, +async def _handle_follow_follow_activity( + db_session: AsyncSession, from_actor: models.Actor, inbox_object: models.InboxObject, ) -> None: @@ -419,8 +445,8 @@ def _handle_follow_follow_activity( ap_actor_id=from_actor.ap_id, ) try: - db.add(follower) - db.flush() + db_session.add(follower) + await db_session.flush() except IntegrityError: pass # TODO update the existing followe @@ -433,20 +459,20 @@ def _handle_follow_follow_activity( "actor": ID, "object": inbox_object.ap_id, } - outbox_activity = save_outbox_object(db, reply_id, reply) + outbox_activity = await save_outbox_object(db_session, reply_id, reply) if not outbox_activity.id: raise ValueError("Should never happen") - new_outgoing_activity(db, from_actor.inbox_url, outbox_activity.id) + await new_outgoing_activity(db_session, from_actor.inbox_url, outbox_activity.id) notif = models.Notification( notification_type=models.NotificationType.NEW_FOLLOWER, actor_id=from_actor.id, ) - db.add(notif) + db_session.add(notif) -def _handle_undo_activity( - db: Session, +async def _handle_undo_activity( + db_session: AsyncSession, from_actor: models.Actor, undo_activity: models.InboxObject, ap_activity_to_undo: models.InboxObject, @@ -462,7 +488,7 @@ def _handle_undo_activity( if ap_activity_to_undo.ap_type == "Follow": logger.info(f"Undo follow from {from_actor.ap_id}") - db.execute( + await db_session.execute( delete(models.Follower).where( models.Follower.inbox_object_id == ap_activity_to_undo.id ) @@ -471,13 +497,13 @@ def _handle_undo_activity( notification_type=models.NotificationType.UNFOLLOW, actor_id=from_actor.id, ) - db.add(notif) + db_session.add(notif) elif ap_activity_to_undo.ap_type == "Like": if not ap_activity_to_undo.activity_object_ap_id: raise ValueError("Like without object") - liked_obj = get_outbox_object_by_ap_id( - db, + liked_obj = await get_outbox_object_by_ap_id( + db_session, ap_activity_to_undo.activity_object_ap_id, ) if not liked_obj: @@ -494,7 +520,7 @@ def _handle_undo_activity( outbox_object_id=liked_obj.id, inbox_object_id=ap_activity_to_undo.id, ) - db.add(notif) + db_session.add(notif) elif ap_activity_to_undo.ap_type == "Announce": if not ap_activity_to_undo.activity_object_ap_id: @@ -504,8 +530,8 @@ def _handle_undo_activity( f"Undo for announce {ap_activity_to_undo.ap_id}/{announced_obj_ap_id}" ) if announced_obj_ap_id.startswith(BASE_URL): - announced_obj_from_outbox = get_outbox_object_by_ap_id( - db, announced_obj_ap_id + announced_obj_from_outbox = await get_outbox_object_by_ap_id( + db_session, announced_obj_ap_id ) if announced_obj_from_outbox: logger.info("Found in the oubox") @@ -518,7 +544,7 @@ def _handle_undo_activity( outbox_object_id=announced_obj_from_outbox.id, inbox_object_id=ap_activity_to_undo.id, ) - db.add(notif) + db_session.add(notif) # FIXME(ts): what to do with ap_activity_to_undo? flag? delete? else: @@ -527,8 +553,8 @@ def _handle_undo_activity( # commit will be perfomed in save_to_inbox -def _handle_create_activity( - db: Session, +async def _handle_create_activity( + db_session: AsyncSession, from_actor: models.Actor, created_object: models.InboxObject, ) -> None: @@ -544,7 +570,7 @@ def _handle_create_activity( return None if created_object.in_reply_to and created_object.in_reply_to.startswith(BASE_URL): - db.execute( + await db_session.execute( update(models.OutboxObject) .where( models.OutboxObject.ap_id == created_object.in_reply_to, @@ -559,12 +585,12 @@ def _handle_create_activity( actor_id=from_actor.id, inbox_object_id=created_object.id, ) - db.add(notif) + db_session.add(notif) -def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: +async def save_to_inbox(db_session: AsyncSession, raw_object: ap.RawObject) -> None: try: - actor = fetch_actor(db, ap.get_id(raw_object["actor"])) + actor = await fetch_actor(db_session, ap.get_id(raw_object["actor"])) except httpx.HTTPStatusError: logger.exception("Failed to fetch actor") return @@ -576,7 +602,7 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: ra = RemoteObject(ap.unwrap_activity(raw_object), actor=actor) if ( - db.scalar( + await db_session.scalar( select(func.count(models.InboxObject.id)).where( models.InboxObject.ap_id == ra.ap_id ) @@ -590,13 +616,13 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: relates_to_outbox_object: models.OutboxObject | None = None if ra.activity_object_ap_id: if ra.activity_object_ap_id.startswith(BASE_URL): - relates_to_outbox_object = get_outbox_object_by_ap_id( - db, + relates_to_outbox_object = await get_outbox_object_by_ap_id( + db_session, ra.activity_object_ap_id, ) else: - relates_to_inbox_object = get_inbox_object_by_ap_id( - db, + relates_to_inbox_object = await get_inbox_object_by_ap_id( + db_session, ra.activity_object_ap_id, ) @@ -625,27 +651,29 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: ), # TODO: handle mentions ) - db.add(inbox_object) - db.flush() - db.refresh(inbox_object) + db_session.add(inbox_object) + await db_session.flush() + await db_session.refresh(inbox_object) if ra.ap_type == "Note": # TODO: handle create better - _handle_create_activity(db, actor, inbox_object) + await _handle_create_activity(db_session, actor, inbox_object) elif ra.ap_type == "Update": pass elif ra.ap_type == "Delete": if relates_to_inbox_object: - _handle_delete_activity(db, actor, relates_to_inbox_object) + await _handle_delete_activity(db_session, actor, relates_to_inbox_object) else: # TODO(ts): handle delete actor logger.info( f"Received a Delete for an unknown object: {ra.activity_object_ap_id}" ) elif ra.ap_type == "Follow": - _handle_follow_follow_activity(db, actor, inbox_object) + await _handle_follow_follow_activity(db_session, actor, inbox_object) elif ra.ap_type == "Undo": if relates_to_inbox_object: - _handle_undo_activity(db, actor, inbox_object, relates_to_inbox_object) + await _handle_undo_activity( + db_session, actor, inbox_object, relates_to_inbox_object + ) else: logger.info("Received Undo for an unknown activity") elif ra.ap_type in ["Accept", "Reject"]: @@ -661,7 +689,7 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: outbox_object_id=relates_to_outbox_object.id, ap_actor_id=actor.ap_id, ) - db.add(following) + db_session.add(following) else: logger.info( "Received an Accept for an unsupported activity: " @@ -689,7 +717,7 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: outbox_object_id=relates_to_outbox_object.id, inbox_object_id=inbox_object.id, ) - db.add(notif) + db_session.add(notif) elif raw_object["type"] == "Announce": if relates_to_outbox_object: # This is an announce for a local object @@ -703,7 +731,7 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: outbox_object_id=relates_to_outbox_object.id, inbox_object_id=inbox_object.id, ) - db.add(notif) + db_session.add(notif) else: # This is announce for a maybe unknown object if relates_to_inbox_object: @@ -713,7 +741,9 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: if not ra.activity_object_ap_id: raise ValueError("Should never happen") announced_raw_object = ap.fetch(ra.activity_object_ap_id) - announced_actor = fetch_actor(db, ap.get_actor_id(announced_raw_object)) + announced_actor = await fetch_actor( + db_session, ap.get_actor_id(announced_raw_object) + ) announced_object = RemoteObject(announced_raw_object, announced_actor) announced_inbox_object = models.InboxObject( server=urlparse(announced_object.ap_id).netloc, @@ -727,8 +757,8 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: visibility=announced_object.visibility, is_hidden_from_stream=True, ) - db.add(announced_inbox_object) - db.flush() + db_session.add(announced_inbox_object) + await db_session.flush() inbox_object.relates_to_inbox_object_id = announced_inbox_object.id elif ra.ap_type in ["Like", "Announce"]: if not relates_to_outbox_object: @@ -749,7 +779,7 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: outbox_object_id=relates_to_outbox_object.id, inbox_object_id=inbox_object.id, ) - db.add(notif) + db_session.add(notif) elif raw_object["type"] == "Announce": # TODO(ts): notification relates_to_outbox_object.announces_count = ( @@ -762,18 +792,18 @@ def save_to_inbox(db: Session, raw_object: ap.RawObject) -> None: outbox_object_id=relates_to_outbox_object.id, inbox_object_id=inbox_object.id, ) - db.add(notif) + db_session.add(notif) else: raise ValueError("Should never happen") else: logger.warning(f"Received an unknown {inbox_object.ap_type} object") - db.commit() + await db_session.commit() -def public_outbox_objects_count(db: Session) -> int: - return db.scalar( +async def public_outbox_objects_count(db_session: AsyncSession) -> int: + return await db_session.scalar( select(func.count(models.OutboxObject.id)).where( models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, models.OutboxObject.is_deleted.is_(False), @@ -781,12 +811,16 @@ def public_outbox_objects_count(db: Session) -> int: ) -def fetch_actor_collection(db: Session, url: str) -> list[Actor]: +async def fetch_actor_collection(db_session: AsyncSession, url: str) -> list[Actor]: if url.startswith(config.BASE_URL): if url == config.BASE_URL + "/followers": followers = ( - db.scalars( - select(models.Follower).options(joinedload(models.Follower.actor)) + ( + await db_session.scalars( + select(models.Follower).options( + joinedload(models.Follower.actor) + ) + ) ) .unique() .all() @@ -806,24 +840,28 @@ class ReplyTreeNode: is_root: bool = False -def get_replies_tree( - db: Session, +async def get_replies_tree( + db_session: AsyncSession, requested_object: AnyboxObject, ) -> ReplyTreeNode: # TODO: handle visibility tree_nodes: list[AnyboxObject] = [] tree_nodes.extend( - db.scalars( - select(models.InboxObject).where( - models.InboxObject.ap_context == requested_object.ap_context, + ( + await db_session.scalars( + select(models.InboxObject).where( + models.InboxObject.ap_context == requested_object.ap_context, + ) ) ).all() ) tree_nodes.extend( - db.scalars( - select(models.OutboxObject).where( - models.OutboxObject.ap_context == requested_object.ap_context, - models.OutboxObject.is_deleted.is_(False), + ( + await db_session.scalars( + select(models.OutboxObject).where( + models.OutboxObject.ap_context == requested_object.ap_context, + models.OutboxObject.is_deleted.is_(False), + ) ) ).all() ) diff --git a/app/config.py b/app/config.py index ac65ebd..2029ad8 100644 --- a/app/config.py +++ b/app/config.py @@ -33,7 +33,7 @@ class Config(pydantic.BaseModel): debug: bool = False # Config items to make tests easier - sqlalchemy_database_url: str | None = None + sqlalchemy_database: str | None = None key_path: str | None = None @@ -73,8 +73,8 @@ ID = f"{_SCHEME}://{DOMAIN}" USERNAME = CONFIG.username BASE_URL = ID DEBUG = CONFIG.debug -DB_PATH = ROOT_DIR / "data" / "microblogpub.db" -SQLALCHEMY_DATABASE_URL = CONFIG.sqlalchemy_database_url or f"sqlite:///{DB_PATH}" +DB_PATH = CONFIG.sqlalchemy_database or ROOT_DIR / "data" / "microblogpub.db" +SQLALCHEMY_DATABASE_URL = f"sqlite:///{DB_PATH}" KEY_PATH = ( (ROOT_DIR / CONFIG.key_path) if CONFIG.key_path else ROOT_DIR / "data" / "key.pem" ) diff --git a/app/database.py b/app/database.py index 3aca462..1e2e31d 100644 --- a/app/database.py +++ b/app/database.py @@ -1,12 +1,14 @@ import datetime from typing import Any -from typing import Generator +from typing import AsyncGenerator from sqlalchemy import create_engine +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from app.config import DB_PATH from app.config import SQLALCHEMY_DATABASE_URL engine = create_engine( @@ -14,6 +16,10 @@ engine = create_engine( ) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +DATABASE_URL = f"sqlite+aiosqlite:///{DB_PATH}" +async_engine = create_async_engine(DATABASE_URL, future=True, echo=False) +async_session = sessionmaker(async_engine, class_=AsyncSession, expire_on_commit=False) + Base: Any = declarative_base() @@ -21,9 +27,6 @@ def now() -> datetime.datetime: return datetime.datetime.now(datetime.timezone.utc) -def get_db() -> Generator[Session, None, None]: - db = SessionLocal() - try: - yield db - finally: - db.close() +async def get_db_session() -> AsyncGenerator[AsyncSession, None]: + async with async_session() as session: + yield session diff --git a/app/lookup.py b/app/lookup.py index 63c6d1a..0480e90 100644 --- a/app/lookup.py +++ b/app/lookup.py @@ -1,14 +1,14 @@ import mf2py # type: ignore -from sqlalchemy.orm import Session from app import activitypub as ap from app import webfinger from app.actor import Actor from app.actor import fetch_actor from app.ap_object import RemoteObject +from app.database import AsyncSession -def lookup(db: Session, query: str) -> Actor | RemoteObject: +async def lookup(db_session: AsyncSession, query: str) -> Actor | RemoteObject: if query.startswith("@"): query = webfinger.get_actor_url(query) # type: ignore # None check below @@ -34,7 +34,7 @@ def lookup(db: Session, query: str) -> Actor | RemoteObject: raise if ap_obj["type"] in ap.ACTOR_TYPES: - actor = fetch_actor(db, ap_obj["id"]) + actor = await fetch_actor(db_session, ap_obj["id"]) return actor else: return RemoteObject(ap_obj) diff --git a/app/main.py b/app/main.py index 8c260d9..325e572 100644 --- a/app/main.py +++ b/app/main.py @@ -24,7 +24,6 @@ from loguru import logger from PIL import Image from sqlalchemy import func from sqlalchemy import select -from sqlalchemy.orm import Session from sqlalchemy.orm import joinedload from starlette.background import BackgroundTask from starlette.responses import JSONResponse @@ -49,7 +48,8 @@ from app.config import USERNAME from app.config import generate_csrf_token from app.config import is_activitypub_requested from app.config import verify_csrf_token -from app.database import get_db +from app.database import AsyncSession +from app.database import get_db_session from app.templates import is_current_user_admin from app.uploads import UPLOAD_DIR from app.utils import pagination @@ -139,9 +139,9 @@ class ActivityPubResponse(JSONResponse): @app.get("/") -def index( +async def index( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), page: int | None = None, ) -> templates.TemplateResponse | ActivityPubResponse: @@ -155,27 +155,26 @@ def index( models.OutboxObject.is_hidden_from_homepage.is_(False), ) q = select(models.OutboxObject).where(*where) - total_count = db.scalar(select(func.count(models.OutboxObject.id)).where(*where)) + total_count = await db_session.scalar( + select(func.count(models.OutboxObject.id)).where(*where) + ) page_size = 20 page_offset = (page - 1) * page_size - outbox_objects = ( - db.scalars( - q.options( - joinedload(models.OutboxObject.outbox_object_attachments).options( - joinedload(models.OutboxObjectAttachment.upload) - ) + outbox_objects_result = await db_session.scalars( + q.options( + joinedload(models.OutboxObject.outbox_object_attachments).options( + joinedload(models.OutboxObjectAttachment.upload) ) - .order_by(models.OutboxObject.ap_published_at.desc()) - .offset(page_offset) - .limit(page_size) ) - .unique() - .all() + .order_by(models.OutboxObject.ap_published_at.desc()) + .offset(page_offset) + .limit(page_size) ) + outbox_objects = outbox_objects_result.unique().all() - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "index.html", { @@ -188,14 +187,14 @@ def index( ) -def _build_followx_collection( - db: Session, +async def _build_followx_collection( + db_session: AsyncSession, model_cls: Type[models.Following | models.Follower], path: str, page: bool | None, next_cursor: str | None, ) -> ap.RawObject: - total_items = db.query(model_cls).count() + total_items = await db_session.scalar(select(func.count(model_cls.id))) if not page and not next_cursor: return { @@ -213,11 +212,11 @@ def _build_followx_collection( ) q = q.limit(20) - items = [followx for followx in db.scalars(q).all()] + items = [followx for followx in (await db_session.scalars(q)).all()] next_cursor = None if ( items - and db.scalar( + and await db_session.scalar( select(func.count(model_cls.id)).where( model_cls.created_at < items[-1].created_at ) @@ -244,18 +243,18 @@ def _build_followx_collection( @app.get("/followers") -def followers( +async def followers( request: Request, page: bool | None = None, next_cursor: str | None = None, prev_cursor: str | None = None, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse | templates.TemplateResponse: if is_activitypub_requested(request): return ActivityPubResponse( - _build_followx_collection( - db=db, + await _build_followx_collection( + db_session=db_session, model_cls=models.Follower, path="/followers", page=page, @@ -264,26 +263,23 @@ def followers( ) # We only show the most recent 20 followers on the public website - followers = ( - db.scalars( - select(models.Follower) - .options(joinedload(models.Follower.actor)) - .order_by(models.Follower.created_at.desc()) - .limit(20) - ) - .unique() - .all() + followers_result = await db_session.scalars( + select(models.Follower) + .options(joinedload(models.Follower.actor)) + .order_by(models.Follower.created_at.desc()) + .limit(20) ) + followers = followers_result.unique().all() actors_metadata = {} if is_current_user_admin(request): - actors_metadata = get_actors_metadata( - db, + actors_metadata = await get_actors_metadata( + db_session, [f.actor for f in followers], ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "followers.html", { @@ -294,18 +290,18 @@ def followers( @app.get("/following") -def following( +async def following( request: Request, page: bool | None = None, next_cursor: str | None = None, prev_cursor: str | None = None, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse | templates.TemplateResponse: if is_activitypub_requested(request): return ActivityPubResponse( - _build_followx_collection( - db=db, + await _build_followx_collection( + db_session=db_session, model_cls=models.Following, path="/following", page=page, @@ -315,10 +311,12 @@ def following( # We only show the most recent 20 follows on the public website following = ( - db.scalars( - select(models.Following) - .options(joinedload(models.Following.actor)) - .order_by(models.Following.created_at.desc()) + ( + await db_session.scalars( + select(models.Following) + .options(joinedload(models.Following.actor)) + .order_by(models.Following.created_at.desc()) + ) ) .unique() .all() @@ -327,13 +325,13 @@ def following( # TODO: support next_cursor/prev_cursor actors_metadata = {} if is_current_user_admin(request): - actors_metadata = get_actors_metadata( - db, + actors_metadata = await get_actors_metadata( + db_session, [f.actor for f in following], ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "following.html", { @@ -344,19 +342,21 @@ def following( @app.get("/outbox") -def outbox( - db: Session = Depends(get_db), +async def outbox( + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse: # By design, we only show the last 20 public activities in the oubox - outbox_objects = db.scalars( - select(models.OutboxObject) - .where( - models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, - models.OutboxObject.is_deleted.is_(False), + outbox_objects = ( + await db_session.scalars( + select(models.OutboxObject) + .where( + models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, + models.OutboxObject.is_deleted.is_(False), + ) + .order_by(models.OutboxObject.ap_published_at.desc()) + .limit(20) ) - .order_by(models.OutboxObject.ap_published_at.desc()) - .limit(20) ).all() return ActivityPubResponse( { @@ -373,19 +373,21 @@ def outbox( @app.get("/featured") -def featured( - db: Session = Depends(get_db), +async def featured( + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse: - outbox_objects = db.scalars( - select(models.OutboxObject) - .filter( - models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, - models.OutboxObject.is_deleted.is_(False), - models.OutboxObject.is_pinned.is_(True), + outbox_objects = ( + await db_session.scalars( + select(models.OutboxObject) + .filter( + models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, + models.OutboxObject.is_deleted.is_(False), + models.OutboxObject.is_pinned.is_(True), + ) + .order_by(models.OutboxObject.ap_published_at.desc()) + .limit(5) ) - .order_by(models.OutboxObject.ap_published_at.desc()) - .limit(5) ).all() return ActivityPubResponse( { @@ -398,9 +400,9 @@ def featured( ) -def _check_outbox_object_acl( +async def _check_outbox_object_acl( request: Request, - db: Session, + db_session: AsyncSession, ap_object: models.OutboxObject, httpsig_info: httpsig.HTTPSigInfo, ) -> None: @@ -413,7 +415,9 @@ def _check_outbox_object_acl( ]: return None elif ap_object.visibility == ap.VisibilityEnum.FOLLOWERS_ONLY: - followers = boxes.fetch_actor_collection(db, BASE_URL + "/followers") + followers = await boxes.fetch_actor_collection( + db_session, BASE_URL + "/followers" + ) if httpsig_info.signed_by_ap_actor_id in [actor.ap_id for actor in followers]: return None elif ap_object.visibility == ap.VisibilityEnum.DIRECT: @@ -425,23 +429,25 @@ def _check_outbox_object_acl( @app.get("/o/{public_id}") -def outbox_by_public_id( +async def outbox_by_public_id( public_id: str, request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), httpsig_info: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse | templates.TemplateResponse: maybe_object = ( - db.execute( - select(models.OutboxObject) - .options( - joinedload(models.OutboxObject.outbox_object_attachments).options( - joinedload(models.OutboxObjectAttachment.upload) + ( + await db_session.execute( + select(models.OutboxObject) + .options( + joinedload(models.OutboxObject.outbox_object_attachments).options( + joinedload(models.OutboxObjectAttachment.upload) + ) + ) + .where( + models.OutboxObject.public_id == public_id, + models.OutboxObject.is_deleted.is_(False), ) - ) - .where( - models.OutboxObject.public_id == public_id, - models.OutboxObject.is_deleted.is_(False), ) ) .unique() @@ -450,45 +456,49 @@ def outbox_by_public_id( if not maybe_object: raise HTTPException(status_code=404) - _check_outbox_object_acl(request, db, maybe_object, httpsig_info) + await _check_outbox_object_acl(request, db_session, maybe_object, httpsig_info) if is_activitypub_requested(request): return ActivityPubResponse(maybe_object.ap_object) - replies_tree = boxes.get_replies_tree(db, maybe_object) + replies_tree = await boxes.get_replies_tree(db_session, maybe_object) likes = ( - db.scalars( - select(models.InboxObject) - .where( - models.InboxObject.ap_type == "Like", - models.InboxObject.activity_object_ap_id == maybe_object.ap_id, + ( + await db_session.scalars( + select(models.InboxObject) + .where( + models.InboxObject.ap_type == "Like", + models.InboxObject.activity_object_ap_id == maybe_object.ap_id, + ) + .options(joinedload(models.InboxObject.actor)) + .order_by(models.InboxObject.ap_published_at.desc()) + .limit(10) ) - .options(joinedload(models.InboxObject.actor)) - .order_by(models.InboxObject.ap_published_at.desc()) - .limit(10) ) .unique() .all() ) shares = ( - db.scalars( - select(models.InboxObject) - .filter( - models.InboxObject.ap_type == "Announce", - models.InboxObject.activity_object_ap_id == maybe_object.ap_id, + ( + await db_session.scalars( + select(models.InboxObject) + .filter( + models.InboxObject.ap_type == "Announce", + models.InboxObject.activity_object_ap_id == maybe_object.ap_id, + ) + .options(joinedload(models.InboxObject.actor)) + .order_by(models.InboxObject.ap_published_at.desc()) + .limit(10) ) - .options(joinedload(models.InboxObject.actor)) - .order_by(models.InboxObject.ap_published_at.desc()) - .limit(10) ) .unique() .all() ) - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "object.html", { @@ -501,31 +511,33 @@ def outbox_by_public_id( @app.get("/o/{public_id}/activity") -def outbox_activity_by_public_id( +async def outbox_activity_by_public_id( public_id: str, request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), httpsig_info: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse: - maybe_object = db.execute( - select(models.OutboxObject).where( - models.OutboxObject.public_id == public_id, - models.OutboxObject.is_deleted.is_(False), + maybe_object = ( + await db_session.execute( + select(models.OutboxObject).where( + models.OutboxObject.public_id == public_id, + models.OutboxObject.is_deleted.is_(False), + ) ) ).scalar_one_or_none() if not maybe_object: raise HTTPException(status_code=404) - _check_outbox_object_acl(request, db, maybe_object, httpsig_info) + await _check_outbox_object_acl(request, db_session, maybe_object, httpsig_info) return ActivityPubResponse(ap.wrap_object(maybe_object.ap_object)) @app.get("/t/{tag}") -def tag_by_name( +async def tag_by_name( tag: str, request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), _: httpsig.HTTPSigInfo = Depends(httpsig.httpsig_checker), ) -> ActivityPubResponse | templates.TemplateResponse: # TODO(ts): implement HTML version @@ -554,23 +566,23 @@ def emoji_by_name(name: str) -> ActivityPubResponse: @app.post("/inbox") async def inbox( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), httpsig_info: httpsig.HTTPSigInfo = Depends(httpsig.enforce_httpsig), ) -> Response: logger.info(f"headers={request.headers}") payload = await request.json() logger.info(f"{payload=}") - save_to_inbox(db, payload) + await save_to_inbox(db_session, payload) return Response(status_code=204) @app.get("/remote_follow") -def get_remote_follow( +async def get_remote_follow( request: Request, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ) -> templates.TemplateResponse: - return templates.render_template( - db, + return await templates.render_template( + db_session, request, "remote_follow.html", {"remote_follow_csrf_token": generate_csrf_token()}, @@ -578,9 +590,8 @@ def get_remote_follow( @app.post("/remote_follow") -def post_remote_follow( +async def post_remote_follow( request: Request, - db: Session = Depends(get_db), csrf_check: None = Depends(verify_csrf_token), profile: str = Form(), ) -> RedirectResponse: @@ -598,7 +609,7 @@ def post_remote_follow( @app.get("/.well-known/webfinger") -def wellknown_webfinger(resource: str) -> JSONResponse: +async def wellknown_webfinger(resource: str) -> JSONResponse: """Exposes/servers WebFinger data.""" omg = f"acct:{USERNAME}@{DOMAIN}" logger.info(f"{resource == omg}/{resource}/{omg}/{len(resource)}/{len(omg)}") @@ -639,10 +650,10 @@ async def well_known_nodeinfo() -> dict[str, Any]: @app.get("/nodeinfo") -def nodeinfo( - db: Session = Depends(get_db), +async def nodeinfo( + db_session: AsyncSession = Depends(get_db_session), ): - local_posts = public_outbox_objects_count(db) + local_posts = await public_outbox_objects_count(db_session) return JSONResponse( { "version": "2.1", @@ -780,14 +791,16 @@ def serve_proxy_media_resized( @app.get("/attachments/{content_hash}/{filename}") -def serve_attachment( +async def serve_attachment( content_hash: str, filename: str, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ): - upload = db.execute( - select(models.Upload).where( - models.Upload.content_hash == content_hash, + upload = ( + await db_session.execute( + select(models.Upload).where( + models.Upload.content_hash == content_hash, + ) ) ).scalar_one_or_none() if not upload: @@ -800,14 +813,16 @@ def serve_attachment( @app.get("/attachments/thumbnails/{content_hash}/{filename}") -def serve_attachment_thumbnail( +async def serve_attachment_thumbnail( content_hash: str, filename: str, - db: Session = Depends(get_db), + db_session: AsyncSession = Depends(get_db_session), ): - upload = db.execute( - select(models.Upload).where( - models.Upload.content_hash == content_hash, + upload = ( + await db_session.execute( + select(models.Upload).where( + models.Upload.content_hash == content_hash, + ) ) ).scalar_one_or_none() if not upload or not upload.has_thumbnail: @@ -827,24 +842,35 @@ Disallow: /following Disallow: /admin""" -def _get_outbox_for_feed(db: Session) -> list[models.OutboxObject]: - return db.scalars( - select(models.OutboxObject) - .where( - models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, - models.OutboxObject.is_deleted.is_(False), - models.OutboxObject.ap_type.in_(["Note", "Article", "Video"]), +async def _get_outbox_for_feed(db_session: AsyncSession) -> list[models.OutboxObject]: + return ( + ( + await db_session.scalars( + select(models.OutboxObject) + .where( + models.OutboxObject.visibility == ap.VisibilityEnum.PUBLIC, + models.OutboxObject.is_deleted.is_(False), + models.OutboxObject.ap_type.in_(["Note", "Article", "Video"]), + ) + .options( + joinedload(models.OutboxObject.outbox_object_attachments).options( + joinedload(models.OutboxObjectAttachment.upload) + ) + ) + .order_by(models.OutboxObject.ap_published_at.desc()) + .limit(20) + ) ) - .order_by(models.OutboxObject.ap_published_at.desc()) - .limit(20) - ).all() + .unique() + .all() + ) @app.get("/feed.json") -def json_feed( - db: Session = Depends(get_db), +async def json_feed( + db_session: AsyncSession = Depends(get_db_session), ) -> dict[str, Any]: - outbox_objects = _get_outbox_for_feed(db) + outbox_objects = await _get_outbox_for_feed(db_session) data = [] for outbox_object in outbox_objects: if not outbox_object.ap_published_at: @@ -876,8 +902,8 @@ def json_feed( } -def _gen_rss_feed( - db: Session, +async def _gen_rss_feed( + db_session: AsyncSession, ): fg = FeedGenerator() fg.id(BASE_URL + "/feed.rss") @@ -888,7 +914,7 @@ def _gen_rss_feed( fg.logo(LOCAL_ACTOR.icon_url) fg.language("en") - outbox_objects = _get_outbox_for_feed(db) + outbox_objects = await _get_outbox_for_feed(db_session) for outbox_object in outbox_objects: if not outbox_object.ap_published_at: raise ValueError(f"{outbox_object} has no published date") @@ -904,20 +930,20 @@ def _gen_rss_feed( @app.get("/feed.rss") -def rss_feed( - db: Session = Depends(get_db), +async def rss_feed( + db_session: AsyncSession = Depends(get_db_session), ) -> PlainTextResponse: return PlainTextResponse( - _gen_rss_feed(db).rss_str(), + (await _gen_rss_feed(db_session)).rss_str(), headers={"Content-Type": "application/rss+xml"}, ) @app.get("/feed.atom") -def atom_feed( - db: Session = Depends(get_db), +async def atom_feed( + db_session: AsyncSession = Depends(get_db_session), ) -> PlainTextResponse: return PlainTextResponse( - _gen_rss_feed(db).atom_str(), + (await _gen_rss_feed(db_session)).atom_str(), headers={"Content-Type": "application/atom+xml"}, ) diff --git a/app/outgoing_activities.py b/app/outgoing_activities.py index 7b965e2..ed180de 100644 --- a/app/outgoing_activities.py +++ b/app/outgoing_activities.py @@ -11,15 +11,23 @@ from sqlalchemy import select from sqlalchemy.orm import Session from app import activitypub as ap +from app import config +from app import ldsig from app import models +from app.database import AsyncSession from app.database import SessionLocal from app.database import now +from app.key import Key +from app.key import get_key _MAX_RETRIES = 16 +k = Key(config.ID, f"{config.ID}#main-key") +k.load(get_key()) -def new_outgoing_activity( - db: Session, + +async def new_outgoing_activity( + db_session: AsyncSession, recipient: str, outbox_object_id: int, ) -> models.OutgoingActivity: @@ -28,9 +36,9 @@ def new_outgoing_activity( outbox_object_id=outbox_object_id, ) - db.add(outgoing_activity) - db.commit() - db.refresh(outgoing_activity) + db_session.add(outgoing_activity) + await db_session.commit() + await db_session.refresh(outgoing_activity) return outgoing_activity @@ -91,6 +99,8 @@ def process_next_outgoing_activity(db: Session) -> bool: next_activity.last_try = now() payload = ap.wrap_object_if_needed(next_activity.outbox_object.ap_object) + if payload["type"] == "Create": + ldsig.generate_signature(payload, k) logger.info(f"{payload=}") try: resp = ap.post(next_activity.recipient, payload) diff --git a/app/source.py b/app/source.py index f8b366b..1ae0a87 100644 --- a/app/source.py +++ b/app/source.py @@ -2,13 +2,13 @@ import re from markdown import markdown from sqlalchemy import select -from sqlalchemy.orm import Session from app import models from app import webfinger from app.actor import Actor from app.actor import fetch_actor from app.config import BASE_URL +from app.database import AsyncSession from app.utils import emoji @@ -24,7 +24,9 @@ _HASHTAG_REGEX = re.compile(r"(#[\d\w]+)") _MENTION_REGEX = re.compile(r"@[\d\w_.+-]+@[\d\w-]+\.[\d\w\-.]+") -def _hashtagify(db: Session, content: str) -> tuple[str, list[dict[str, str]]]: +async def _hashtagify( + db_session: AsyncSession, content: str +) -> tuple[str, list[dict[str, str]]]: tags = [] hashtags = re.findall(_HASHTAG_REGEX, content) hashtags = sorted(set(hashtags), reverse=True) # unique tags, longest first @@ -36,23 +38,25 @@ def _hashtagify(db: Session, content: str) -> tuple[str, list[dict[str, str]]]: return content, tags -def _mentionify( - db: Session, +async def _mentionify( + db_session: AsyncSession, content: str, ) -> tuple[str, list[dict[str, str]], list[Actor]]: tags = [] mentioned_actors = [] for mention in re.findall(_MENTION_REGEX, content): _, username, domain = mention.split("@") - actor = db.execute( - select(models.Actor).where(models.Actor.handle == mention) + actor = ( + await db_session.execute( + select(models.Actor).where(models.Actor.handle == mention) + ) ).scalar_one_or_none() if not actor: actor_url = webfinger.get_actor_url(mention) if not actor_url: # FIXME(ts): raise an error? continue - actor = fetch_actor(db, actor_url) + actor = await fetch_actor(db_session, actor_url) mentioned_actors.append(actor) tags.append(dict(type="Mention", href=actor.url, name=mention)) @@ -62,8 +66,8 @@ def _mentionify( return content, tags, mentioned_actors -def markdownify( - db: Session, +async def markdownify( + db_session: AsyncSession, content: str, mentionify: bool = True, hashtagify: bool = True, @@ -75,10 +79,10 @@ def markdownify( tags = [] mentioned_actors: list[Actor] = [] if hashtagify: - content, hashtag_tags = _hashtagify(db, content) + content, hashtag_tags = await _hashtagify(db_session, content) tags.extend(hashtag_tags) if mentionify: - content, mention_tags, mentioned_actors = _mentionify(db, content) + content, mention_tags, mentioned_actors = await _mentionify(db_session, content) tags.extend(mention_tags) # Handle custom emoji diff --git a/app/templates.py b/app/templates.py index 695ab74..280e0f6 100644 --- a/app/templates.py +++ b/app/templates.py @@ -15,7 +15,6 @@ from fastapi.templating import Jinja2Templates from loguru import logger from sqlalchemy import func from sqlalchemy import select -from sqlalchemy.orm import Session from starlette.templating import _TemplateResponse as TemplateResponse from app import activitypub as ap @@ -29,6 +28,7 @@ from app.config import DEBUG from app.config import VERSION from app.config import generate_csrf_token from app.config import session_serializer +from app.database import AsyncSession from app.database import now from app.media import proxied_media_url from app.utils.highlight import HIGHLIGHT_CSS @@ -77,8 +77,8 @@ def is_current_user_admin(request: Request) -> bool: return is_admin -def render_template( - db: Session, +async def render_template( + db_session: AsyncSession, request: Request, template: str, template_args: dict[str, Any] = {}, @@ -96,7 +96,7 @@ def render_template( "csrf_token": generate_csrf_token() if is_admin else None, "highlight_css": HIGHLIGHT_CSS, "visibility_enum": ap.VisibilityEnum, - "notifications_count": db.scalar( + "notifications_count": await db_session.scalar( select(func.count(models.Notification.id)).where( models.Notification.is_new.is_(True) ) @@ -104,8 +104,12 @@ def render_template( if is_admin else 0, "local_actor": LOCAL_ACTOR, - "followers_count": db.scalar(select(func.count(models.Follower.id))), - "following_count": db.scalar(select(func.count(models.Following.id))), + "followers_count": await db_session.scalar( + select(func.count(models.Follower.id)) + ), + "following_count": await db_session.scalar( + select(func.count(models.Following.id)) + ), **template_args, }, ) diff --git a/app/uploads.py b/app/uploads.py index f1b241c..07a9135 100644 --- a/app/uploads.py +++ b/app/uploads.py @@ -11,12 +11,12 @@ from app import activitypub as ap from app import models from app.config import BASE_URL from app.config import ROOT_DIR -from app.database import Session +from app.database import AsyncSession UPLOAD_DIR = ROOT_DIR / "data" / "uploads" -def save_upload(db: Session, f: UploadFile) -> models.Upload: +async def save_upload(db_session: AsyncSession, f: UploadFile) -> models.Upload: # Compute the hash h = hashlib.blake2b(digest_size=32) while True: @@ -28,8 +28,10 @@ def save_upload(db: Session, f: UploadFile) -> models.Upload: content_hash = h.hexdigest() f.file.seek(0) - existing_upload = db.execute( - select(models.Upload).where(models.Upload.content_hash == content_hash) + existing_upload = ( + await db_session.execute( + select(models.Upload).where(models.Upload.content_hash == content_hash) + ) ).scalar_one_or_none() if existing_upload: logger.info(f"Upload with {content_hash=} already exists") @@ -88,8 +90,8 @@ def save_upload(db: Session, f: UploadFile) -> models.Upload: width=width, height=height, ) - db.add(new_upload) - db.commit() + db_session.add(new_upload) + await db_session.commit() return new_upload diff --git a/data/tests.toml b/data/tests.toml index dabae6a..cf7b5f9 100644 --- a/data/tests.toml +++ b/data/tests.toml @@ -10,7 +10,7 @@ secret = "1dd4079e0474d1a519052b8fe3cb5fa6" debug = true # In-mem DB -sqlalchemy_database_url = "sqlite:///file:pytest?mode=memory&cache=shared&uri=true" -# sqlalchemy_database_url = "sqlite:///data/pytest.db" +sqlalchemy_database = "file:pytest?mode=memory&cache=shared&uri=true" +# sqlalchemy_database_url = "data/pytest.db" key_path = "tests/test.key" media_db_path = "tests/media.db" diff --git a/poetry.lock b/poetry.lock index 24b92fe..d626970 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,3 +1,14 @@ +[[package]] +name = "aiosqlite" +version = "0.17.0" +description = "asyncio bridge to the standard sqlite3 module" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +typing_extensions = ">=3.7.2" + [[package]] name = "alembic" version = "1.8.0" @@ -590,7 +601,7 @@ requests = ">=2.18.4" name = "mypy" version = "0.960" description = "Optional static typing for Python" -category = "main" +category = "dev" optional = false python-versions = ">=3.6" @@ -608,7 +619,7 @@ reports = ["lxml"] name = "mypy-extensions" version = "0.4.3" description = "Experimental type system extensions for programs checked with the mypy typechecker." -category = "main" +category = "dev" optional = false python-versions = "*" @@ -916,7 +927,7 @@ python-versions = ">=3.6" [[package]] name = "sqlalchemy" -version = "1.4.37" +version = "1.4.39" description = "Database Abstraction Library" category = "main" optional = false @@ -924,8 +935,6 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" [package.dependencies] greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} -mypy = {version = ">=0.910", optional = true, markers = "python_version >= \"3\" and extra == \"mypy\""} -sqlalchemy2-stubs = {version = "*", optional = true, markers = "extra == \"mypy\""} [package.extras] aiomysql = ["greenlet (!=0.4.17)", "aiomysql"] @@ -950,7 +959,7 @@ sqlcipher = ["sqlcipher3-binary"] [[package]] name = "sqlalchemy2-stubs" -version = "0.0.2a23" +version = "0.0.2a24" description = "Typing Stubs for SQLAlchemy 1.4" category = "main" optional = false @@ -1134,9 +1143,13 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"] [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "e8f20d21a8c7822fbc3c183376d694fc0109e90851377bc6b7316c5c72e880b0" +content-hash = "19151bbc858317aec5747a8f45a86b47cc198111422cc166a94634ad1941d8bc" [metadata.files] +aiosqlite = [ + {file = "aiosqlite-0.17.0-py3-none-any.whl", hash = "sha256:6c49dc6d3405929b1d08eeccc72306d3677503cc5e5e43771efc1e00232e8231"}, + {file = "aiosqlite-0.17.0.tar.gz", hash = "sha256:f0e6acc24bc4864149267ac82fb46dfb3be4455f99fe21df82609cc6e6baee51"}, +] alembic = [ {file = "alembic-1.8.0-py3-none-any.whl", hash = "sha256:b5ae4bbfc7d1302ed413989d39474d102e7cfa158f6d5969d2497955ffe85a30"}, {file = "alembic-1.8.0.tar.gz", hash = "sha256:a2d4d90da70b30e70352cd9455e35873a255a31402a438fe24815758d7a0e5e1"}, @@ -1846,46 +1859,46 @@ soupsieve = [ {file = "soupsieve-2.3.2.post1.tar.gz", hash = "sha256:fc53893b3da2c33de295667a0e19f078c14bf86544af307354de5fcf12a3f30d"}, ] sqlalchemy = [ - {file = "SQLAlchemy-1.4.37-cp27-cp27m-macosx_10_14_x86_64.whl", hash = "sha256:d9050b0c4a7f5538650c74aaba5c80cd64450e41c206f43ea6d194ae6d060ff9"}, - {file = "SQLAlchemy-1.4.37-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b4c92823889cf9846b972ee6db30c0e3a92c0ddfc76c6060a6cda467aa5fb694"}, - {file = "SQLAlchemy-1.4.37-cp27-cp27m-win32.whl", hash = "sha256:b55932fd0e81b43f4aff397c8ad0b3c038f540af37930423ab8f47a20b117e4c"}, - {file = "SQLAlchemy-1.4.37-cp27-cp27m-win_amd64.whl", hash = "sha256:4a17c1a1152ca4c29d992714aa9df3054da3af1598e02134f2e7314a32ef69d8"}, - {file = "SQLAlchemy-1.4.37-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:ffe487570f47536b96eff5ef2b84034a8ba4e19aab5ab7647e677d94a119ea55"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:78363f400fbda80f866e8e91d37d36fe6313ff847ded08674e272873c1377ea5"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7ee34c85cbda7779d66abac392c306ec78c13f5c73a1f01b8b767916d4895d23"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8b38e088659b30c2ca0af63e5d139fad1779a7925d75075a08717a21c406c0f6"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6629c79967a6c92e33fad811599adf9bc5cee6e504a1027bbf9cc1b6fb2d276d"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-win32.whl", hash = "sha256:2aac2a685feb9882d09f457f4e5586c885d578af4e97a2b759e91e8c457cbce5"}, - {file = "SQLAlchemy-1.4.37-cp310-cp310-win_amd64.whl", hash = "sha256:7a44683cf97744a405103ef8fdd31199e9d7fc41b4a67e9044523b29541662b0"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:cffc67cdd07f0e109a1fc83e333972ae423ea5ad414585b63275b66b870ea62b"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17417327b87a0f703c9a20180f75e953315207d048159aff51822052f3e33e69"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:aaa0e90e527066409c2ea5676282cf4afb4a40bb9dce0f56c8ec2768bff22a6e"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c1d9fb3931e27d59166bb5c4dcc911400fee51082cfba66ceb19ac954ade068"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-win32.whl", hash = "sha256:0e7fd52e48e933771f177c2a1a484b06ea03774fc7741651ebdf19985a34037c"}, - {file = "SQLAlchemy-1.4.37-cp36-cp36m-win_amd64.whl", hash = "sha256:eec39a17bab3f69c44c9df4e0ed87c7306f2d2bf1eca3070af644927ec4199fa"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:caca6acf3f90893d7712ae2c6616ecfeac3581b4cc677c928a330ce6fbad4319"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50c8eaf44c3fed5ba6758d375de25f163e46137c39fda3a72b9ee1d1bb327dfc"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:139c50b9384e6d32a74fc4dcd0e9717f343ed38f95dbacf832c782c68e3862f3"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d4c3b009c9220ae6e33f17b45f43fb46b9a1d281d76118405af13e26376f2e11"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-win32.whl", hash = "sha256:9785d6f962d2c925aeb06a7539ac9d16608877da6aeaaf341984b3693ae80a02"}, - {file = "SQLAlchemy-1.4.37-cp37-cp37m-win_amd64.whl", hash = "sha256:3197441772dc3b1c6419f13304402f2418a18d7fe78000aa5a026e7100836739"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:3862a069a24f354145e01a76c7c720c263d62405fe5bed038c46a7ce900f5dd6"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e8706919829d455a9fa687c6bbd1b048e36fec3919a59f2d366247c2bfdbd9c"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:06ec11a5e6a4b6428167d3ce33b5bd455c020c867dabe3e6951fa98836e0741d"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d58f2d9d1a4b1459e8956a0153a4119da80f54ee5a9ea623cd568e99459a3ef1"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-win32.whl", hash = "sha256:d6927c9e3965b194acf75c8e0fb270b4d54512db171f65faae15ef418721996e"}, - {file = "SQLAlchemy-1.4.37-cp38-cp38-win_amd64.whl", hash = "sha256:a91d0668cada27352432f15b92ac3d43e34d8f30973fa8b86f5e9fddee928f3b"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:f9940528bf9c4df9e3c3872d23078b6b2da6431c19565637c09f1b88a427a684"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29a742c29fea12259f1d2a9ee2eb7fe4694a85d904a4ac66d15e01177b17ad7f"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7e579d6e281cc937bdb59917017ab98e618502067e04efb1d24ac168925e1d2a"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a940c551cfbd2e1e646ceea2777944425f5c3edff914bc808fe734d9e66f8d71"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-win32.whl", hash = "sha256:5e4e517ce72fad35cce364a01aff165f524449e9c959f1837dc71088afa2824c"}, - {file = "SQLAlchemy-1.4.37-cp39-cp39-win_amd64.whl", hash = "sha256:c37885f83b59e248bebe2b35beabfbea398cb40960cdc6d3a76eac863d4e1938"}, - {file = "SQLAlchemy-1.4.37.tar.gz", hash = "sha256:3688f92c62db6c5df268e2264891078f17ecb91e3141b400f2e28d0f75796dea"}, + {file = "SQLAlchemy-1.4.39-cp27-cp27m-macosx_10_14_x86_64.whl", hash = "sha256:4770eb3ba69ec5fa41c681a75e53e0e342ac24c1f9220d883458b5596888e43a"}, + {file = "SQLAlchemy-1.4.39-cp27-cp27m-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:752ef2e8dbaa3c5d419f322e3632f00ba6b1c3230f65bc97c2ff5c5c6c08f441"}, + {file = "SQLAlchemy-1.4.39-cp27-cp27m-win32.whl", hash = "sha256:b30e70f1594ee3c8902978fd71900d7312453922827c4ce0012fa6a8278d6df4"}, + {file = "SQLAlchemy-1.4.39-cp27-cp27m-win_amd64.whl", hash = "sha256:864d4f89f054819cb95e93100b7d251e4d114d1c60bc7576db07b046432af280"}, + {file = "SQLAlchemy-1.4.39-cp27-cp27mu-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:8f901be74f00a13bf375241a778455ee864c2c21c79154aad196b7a994e1144f"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:1745987ada1890b0e7978abdb22c133eca2e89ab98dc17939042240063e1ef21"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ede13a472caa85a13abe5095e71676af985d7690eaa8461aeac5c74f6600b6c0"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7f13644b15665f7322f9e0635129e0ef2098409484df67fcd225d954c5861559"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26146c59576dfe9c546c9f45397a7c7c4a90c25679492ff610a7500afc7d03a6"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-win32.whl", hash = "sha256:91d2b89bb0c302f89e753bea008936acfa4e18c156fb264fe41eb6bbb2bbcdeb"}, + {file = "SQLAlchemy-1.4.39-cp310-cp310-win_amd64.whl", hash = "sha256:50e7569637e2e02253295527ff34666706dbb2bc5f6c61a5a7f44b9610c9bb09"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-macosx_10_14_x86_64.whl", hash = "sha256:107df519eb33d7f8e0d0d052128af2f25066c1a0f6b648fd1a9612ab66800b86"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f24d4d6ec301688c59b0c4bb1c1c94c5d0bff4ecad33bb8f5d9efdfb8d8bc925"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7b2785dd2a0c044a36836857ac27310dc7a99166253551ee8f5408930958cc60"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e6e2c8581c6620136b9530137954a8376efffd57fe19802182c7561b0ab48b48"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-win32.whl", hash = "sha256:fbc076f79d830ae4c9d49926180a1140b49fa675d0f0d555b44c9a15b29f4c80"}, + {file = "SQLAlchemy-1.4.39-cp36-cp36m-win_amd64.whl", hash = "sha256:0ec54460475f0c42512895c99c63d90dd2d9cbd0c13491a184182e85074b04c5"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-macosx_10_15_x86_64.whl", hash = "sha256:6f95706da857e6e79b54c33c1214f5467aab10600aa508ddd1239d5df271986e"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:621f050e72cc7dfd9ad4594ff0abeaad954d6e4a2891545e8f1a53dcdfbef445"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:05a05771617bfa723ba4cef58d5b25ac028b0d68f28f403edebed5b8243b3a87"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:20bf65bcce65c538e68d5df27402b39341fabeecf01de7e0e72b9d9836c13c52"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-win32.whl", hash = "sha256:f2a42acc01568b9701665e85562bbff78ec3e21981c7d51d56717c22e5d3d58b"}, + {file = "SQLAlchemy-1.4.39-cp37-cp37m-win_amd64.whl", hash = "sha256:6d81de54e45f1d756785405c9d06cd17918c2eecc2d4262dc2d276ca612c2f61"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:5c2d19bfb33262bf987ef0062345efd0f54c4189c2d95159c72995457bf4a359"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:14ea8ff2d33c48f8e6c3c472111d893b9e356284d1482102da9678195e5a8eac"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ec3985c883d6d217cf2013028afc6e3c82b8907192ba6195d6e49885bfc4b19d"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1962dfee37b7fb17d3d4889bf84c4ea08b1c36707194c578f61e6e06d12ab90f"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-win32.whl", hash = "sha256:047ef5ccd8860f6147b8ac6c45a4bc573d4e030267b45d9a1c47b55962ff0e6f"}, + {file = "SQLAlchemy-1.4.39-cp38-cp38-win_amd64.whl", hash = "sha256:b71be98ef6e180217d1797185c75507060a57ab9cd835653e0112db16a710f0d"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:365b75938049ae31cf2176efd3d598213ddb9eb883fbc82086efa019a5f649df"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e7a7667d928ba6ee361a3176e1bef6847c1062b37726b33505cc84136f657e0d"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:c6d00cb9da8d0cbfaba18cad046e94b06de6d4d0ffd9d4095a3ad1838af22528"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0538b66f959771c56ff996d828081908a6a52a47c5548faed4a3d0a027a5368"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-win32.whl", hash = "sha256:d1f665e50592caf4cad3caed3ed86f93227bffe0680218ccbb293bd5a6734ca8"}, + {file = "SQLAlchemy-1.4.39-cp39-cp39-win_amd64.whl", hash = "sha256:8b773c9974c272aae0fa7e95b576d98d17ee65f69d8644f9b6ffc90ee96b4d19"}, + {file = "SQLAlchemy-1.4.39.tar.gz", hash = "sha256:8194896038753b46b08a0b0ae89a5d80c897fb601dd51e243ed5720f1f155d27"}, ] sqlalchemy2-stubs = [ - {file = "sqlalchemy2-stubs-0.0.2a23.tar.gz", hash = "sha256:a13d94e23b5b0da8ee21986ef8890788a1f2eb26c2a9f39424cc933e4e7e87ff"}, - {file = "sqlalchemy2_stubs-0.0.2a23-py3-none-any.whl", hash = "sha256:6011d2219365d4e51f3e9d83ffeb5b904964ef1d143dc1298d8a70ce8641014d"}, + {file = "sqlalchemy2-stubs-0.0.2a24.tar.gz", hash = "sha256:e15c45302eafe196ed516f979ef017135fd619d2c62d02de9a5c5f2e59a600c4"}, + {file = "sqlalchemy2_stubs-0.0.2a24-py3-none-any.whl", hash = "sha256:f2399251d3d8f00a88659d711a449c855a0d4e977c7a9134e414f1459b9acc11"}, ] starlette = [ {file = "starlette-0.19.1-py3-none-any.whl", hash = "sha256:5a60c5c2d051f3a8eb546136aa0c9399773a689595e099e0877704d5888279bf"}, diff --git a/pyproject.toml b/pyproject.toml index 98aa982..9e43a3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ python-multipart = "^0.0.5" tomli = "^2.0.1" httpx = "^0.23.0" timeago = "^1.0.15" -SQLAlchemy = {extras = ["mypy"], version = "^1.4.37"} +SQLAlchemy = {extras = ["asyncio"], version = "^1.4.39"} alembic = "^1.8.0" bleach = "^5.0.0" requests = "^2.27.1" @@ -38,6 +38,8 @@ html2text = "^2020.1.16" feedgen = "^0.9.0" emoji = "^1.7.0" PyLD = "^2.0.3" +aiosqlite = "^0.17.0" +sqlalchemy2-stubs = "^0.0.2-alpha.24" [tool.poetry.dev-dependencies] black = "^22.3.0" diff --git a/tests/conftest.py b/tests/conftest.py index f175631..21a8039 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,22 +2,25 @@ from typing import Generator import pytest from fastapi.testclient import TestClient -from sqlalchemy import orm from app.database import Base +from app.database import async_engine +from app.database import async_session from app.database import engine -from app.database import get_db from app.main import app from tests.factories import _Session # _Session = orm.sessionmaker(bind=engine, autocommit=False, autoflush=False) -def _get_db_for_testing() -> Generator[orm.Session, None, None]: - # try: - yield _Session # type: ignore - # finally: - # session.close() +@pytest.fixture +async def async_db_session(): + async with async_session() as session: + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield session + async with async_engine.begin() as conn: + await conn.run_sync(Base.metadata.drop_all) @pytest.fixture @@ -46,6 +49,6 @@ def exclude_fastapi_middleware(): @pytest.fixture def client(db, exclude_fastapi_middleware) -> Generator: - app.dependency_overrides[get_db] = _get_db_for_testing + # app.dependency_overrides[get_db] = _get_db_for_testing with TestClient(app) as c: yield c diff --git a/tests/test_actor.py b/tests/test_actor.py index 801e406..c59f2c0 100644 --- a/tests/test_actor.py +++ b/tests/test_actor.py @@ -1,13 +1,18 @@ import httpx +import pytest import respx +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session from app import models from app.actor import fetch_actor -from app.database import Session +from app.database import AsyncSession from tests import factories -def test_fetch_actor(db: Session, respx_mock) -> None: +@pytest.mark.asyncio +async def test_fetch_actor(async_db_session: AsyncSession, respx_mock) -> None: # Given a remote actor ra = factories.RemoteActorFactory( base_url="https://example.com", @@ -17,18 +22,22 @@ def test_fetch_actor(db: Session, respx_mock) -> None: respx_mock.get(ra.ap_id).mock(return_value=httpx.Response(200, json=ra.ap_actor)) # When fetching this actor for the first time - saved_actor = fetch_actor(db, ra.ap_id) + saved_actor = await fetch_actor(async_db_session, ra.ap_id) # Then it has been fetched and saved in DB assert respx.calls.call_count == 1 - assert db.query(models.Actor).one().ap_id == saved_actor.ap_id + assert ( + await async_db_session.execute(select(models.Actor)) + ).scalar_one().ap_id == saved_actor.ap_id # When fetching it a second time - actor_from_db = fetch_actor(db, ra.ap_id) + actor_from_db = await fetch_actor(async_db_session, ra.ap_id) # Then it's read from the DB assert actor_from_db.ap_id == ra.ap_id - assert db.query(models.Actor).count() == 1 + assert ( + await async_db_session.execute(select(func.count(models.Actor.id))) + ).scalar_one() == 1 assert respx.calls.call_count == 1 diff --git a/tests/test_emoji.py b/tests/test_emoji.py index f5bb621..b459ff3 100644 --- a/tests/test_emoji.py +++ b/tests/test_emoji.py @@ -1,9 +1,9 @@ from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from app import activitypub as ap from app import models from app.config import generate_csrf_token -from app.database import Session from app.utils.emoji import EMOJIS_BY_NAME from tests.utils import generate_admin_session_cookies diff --git a/tests/test_inbox.py b/tests/test_inbox.py index 736f6f3..6f134fa 100644 --- a/tests/test_inbox.py +++ b/tests/test_inbox.py @@ -3,12 +3,12 @@ from uuid import uuid4 import httpx import respx from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from app import activitypub as ap from app import models from app.actor import LOCAL_ACTOR from app.ap_object import RemoteObject -from app.database import Session from tests import factories from tests.utils import mock_httpsig_checker diff --git a/tests/test_outbox.py b/tests/test_outbox.py index 7a978a1..45bdda8 100644 --- a/tests/test_outbox.py +++ b/tests/test_outbox.py @@ -4,6 +4,7 @@ from uuid import uuid4 import httpx import respx from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from app import activitypub as ap from app import models @@ -11,7 +12,6 @@ from app import webfinger from app.actor import LOCAL_ACTOR from app.ap_object import RemoteObject from app.config import generate_csrf_token -from app.database import Session from tests import factories from tests.utils import generate_admin_session_cookies diff --git a/tests/test_process_outgoing_activities.py b/tests/test_process_outgoing_activities.py index a5e5422..f013518 100644 --- a/tests/test_process_outgoing_activities.py +++ b/tests/test_process_outgoing_activities.py @@ -1,13 +1,16 @@ from uuid import uuid4 import httpx +import pytest import respx from fastapi.testclient import TestClient +from sqlalchemy import select +from sqlalchemy.orm import Session from app import models from app.actor import LOCAL_ACTOR from app.ap_object import RemoteObject -from app.database import Session +from app.database import AsyncSession from app.outgoing_activities import _MAX_RETRIES from app.outgoing_activities import new_outgoing_activity from app.outgoing_activities import process_next_outgoing_activity @@ -36,8 +39,9 @@ def _setup_outbox_object() -> models.OutboxObject: return outbox_object -def test_new_outgoing_activity( - db: Session, +@pytest.mark.asyncio +async def test_new_outgoing_activity( + async_db_session: AsyncSession, client: TestClient, respx_mock: respx.MockRouter, ) -> None: @@ -48,9 +52,13 @@ def test_new_outgoing_activity( raise ValueError("Should never happen") # When queuing the activity - outgoing_activity = new_outgoing_activity(db, inbox_url, outbox_object.id) + outgoing_activity = await new_outgoing_activity( + async_db_session, inbox_url, outbox_object.id + ) - assert db.query(models.OutgoingActivity).one() == outgoing_activity + assert ( + await async_db_session.execute(select(models.OutgoingActivity)) + ).scalar_one() == outgoing_activity assert outgoing_activity.outbox_object_id == outbox_object.id assert outgoing_activity.recipient == inbox_url diff --git a/tests/test_public.py b/tests/test_public.py index 0ea60c8..083d518 100644 --- a/tests/test_public.py +++ b/tests/test_public.py @@ -1,9 +1,9 @@ import pytest from fastapi.testclient import TestClient +from sqlalchemy.orm import Session from app import activitypub as ap from app.actor import LOCAL_ACTOR -from app.database import Session _ACCEPTED_AP_HEADERS = [ "application/activity+json",