diff --git a/app/main.py b/app/main.py index e07d6f2..eca9b27 100644 --- a/app/main.py +++ b/app/main.py @@ -9,6 +9,7 @@ from typing import MutableMapping from typing import Type import httpx +import starlette from asgiref.typing import ASGI3Application from asgiref.typing import ASGIReceiveCallable from asgiref.typing import ASGISendCallable @@ -57,7 +58,6 @@ from app.config import DOMAIN from app.config import ID from app.config import USER_AGENT from app.config import USERNAME -from app.config import generate_csrf_token from app.config import is_activitypub_requested from app.config import verify_csrf_token from app.database import AsyncSession @@ -76,6 +76,7 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac # TODO(ts): # # Next: +# - Article support # - indieauth tweaks # - API for posting notes # - allow to block servers @@ -390,7 +391,6 @@ async def following( .all() ) - # TODO: support next_cursor/prev_cursor actors_metadata = {} if is_current_user_admin(request): actors_metadata = await get_actors_metadata( @@ -482,13 +482,17 @@ async def _check_outbox_object_acl( ap.VisibilityEnum.UNLISTED, ]: return None + elif ap_object.visibility == ap.VisibilityEnum.FOLLOWERS_ONLY: + # Is the signing actor a follower? followers = await boxes.fetch_actor_collection( db_session, BASE_URL + "/followers" ) if httpsig_info.signed_by_ap_actor_id in [actor.ap_id for actor in followers]: return None + elif ap_object.visibility == ap.VisibilityEnum.DIRECT: + # Is the signing actor targeted in the object audience? audience = ap_object.ap_object.get("to", []) + ap_object.ap_object.get("cc", []) if httpsig_info.signed_by_ap_actor_id in audience: return None @@ -718,7 +722,7 @@ async def get_remote_follow( db_session, request, "remote_follow.html", - {"remote_follow_csrf_token": generate_csrf_token()}, + {}, ) @@ -733,6 +737,7 @@ async def post_remote_follow( remote_follow_template = await get_remote_follow_template(profile) if not remote_follow_template: + # TODO(ts): error message to user raise HTTPException(status_code=404) return RedirectResponse( @@ -812,12 +817,9 @@ async def nodeinfo( proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True) -@app.get("/proxy/media/{encoded_url}") -async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse: - # Decode the base64-encoded URL - url = base64.urlsafe_b64decode(encoded_url).decode() - check_url(url) - +async def _proxy_get( + request: starlette.requests.Request, url: str, stream: bool +) -> httpx.Response: # Request the URL (and filter request headers) proxy_req = proxy_client.build_request( request.method, @@ -830,27 +832,42 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp ] + [(b"user-agent", USER_AGENT.encode())], ) - proxy_resp = await proxy_client.send(proxy_req, stream=True) - # Filter the headers - proxy_resp_headers = [ - (k, v) - for (k, v) in proxy_resp.headers.items() - if k.lower() - in [ - "content-length", - "content-type", - "content-range", - "accept-ranges" "etag", - "cache-control", - "expires", - "date", - "last-modified", - ] - ] + return await proxy_client.send(proxy_req, stream=stream) + + +def _filter_proxy_resp_headers( + proxy_resp: httpx.Response, + allowed_headers: list[str], +) -> dict[str, str]: + return { + k: v for (k, v) in proxy_resp.headers.items() if k.lower() in allowed_headers + } + + +@app.get("/proxy/media/{encoded_url}") +async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse: + # Decode the base64-encoded URL + url = base64.urlsafe_b64decode(encoded_url).decode() + check_url(url) + + proxy_resp = await _proxy_get(request, url, stream=True) + return StreamingResponse( proxy_resp.aiter_raw(), status_code=proxy_resp.status_code, - headers=dict(proxy_resp_headers), + headers=_filter_proxy_resp_headers( + proxy_resp, + [ + "content-length", + "content-type", + "content-range", + "accept-ranges" "etag", + "cache-control", + "expires", + "date", + "last-modified", + ], + ), background=BackgroundTask(proxy_resp.aclose), ) @@ -876,25 +893,7 @@ async def serve_proxy_media_resized( headers=resp_headers, ) - # Request the URL (and filter request headers) - async with httpx.AsyncClient() as client: - proxy_resp = await client.get( - url, - headers=[ - (k, v) - for (k, v) in request.headers.raw - if k.lower() - not in [ - b"host", - b"cookie", - b"x-forwarded-for", - b"x-real-ip", - b"user-agent", - ] - ] - + [(b"user-agent", USER_AGENT.encode())], - follow_redirects=True, - ) + proxy_resp = await _proxy_get(request, url, stream=False) if proxy_resp.status_code != 200: return PlainTextResponse( proxy_resp.content, @@ -902,18 +901,16 @@ async def serve_proxy_media_resized( ) # Filter the headers - proxy_resp_headers = { - k: v - for (k, v) in proxy_resp.headers.items() - if k.lower() - in [ + proxy_resp_headers = _filter_proxy_resp_headers( + proxy_resp, + [ "content-type", "etag", "cache-control", "expires", "last-modified", - ] - } + ], + ) try: out = BytesIO(proxy_resp.content) diff --git a/app/templates/remote_follow.html b/app/templates/remote_follow.html index b91c10e..bfe5ecc 100644 --- a/app/templates/remote_follow.html +++ b/app/templates/remote_follow.html @@ -11,9 +11,9 @@