Improve caching

This commit is contained in:
Thomas Sileo 2022-06-30 09:25:13 +02:00
parent d371e3cd4f
commit 6458d2a6c7
6 changed files with 87 additions and 24 deletions

View file

@ -8,22 +8,27 @@ import hashlib
import typing import typing
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from functools import lru_cache
from typing import Any from typing import Any
from typing import Dict from typing import Dict
from typing import Optional from typing import Optional
import fastapi import fastapi
import httpx import httpx
from cachetools import LFUCache
from Crypto.Hash import SHA256 from Crypto.Hash import SHA256
from Crypto.Signature import PKCS1_v1_5 from Crypto.Signature import PKCS1_v1_5
from loguru import logger from loguru import logger
from sqlalchemy import select
from app import activitypub as ap from app import activitypub as ap
from app import config from app import config
from app.database import AsyncSession
from app.database import get_db_session
from app.key import Key from app.key import Key
from app.key import get_key from app.key import get_key
_KEY_CACHE = LFUCache(256)
def _build_signed_string( def _build_signed_string(
signed_headers: str, method: str, path: str, headers: Any, body_digest: str | None signed_headers: str, method: str, path: str, headers: Any, body_digest: str | None
@ -62,9 +67,25 @@ def _body_digest(body: bytes) -> str:
return "SHA-256=" + base64.b64encode(h.digest()).decode("utf-8") return "SHA-256=" + base64.b64encode(h.digest()).decode("utf-8")
@lru_cache(32) async def _get_public_key(db_session: AsyncSession, key_id: str) -> Key:
async def _get_public_key(key_id: str) -> Key: if cached_key := _KEY_CACHE.get(key_id):
# TODO: use DB to use cache actor return cached_key
# Check if the key belongs to an actor already in DB
from app import models
existing_actor = (
await db_session.scalars(
select(models.Actor).where(models.Actor.ap_id == key_id.split("#")[0])
)
).one_or_none()
if existing_actor and existing_actor.public_key_id == key_id:
k = Key(existing_actor.ap_id, key_id)
k.load_pub(existing_actor.public_key_as_pem)
logger.info(f"Found {key_id} on an existing actor")
_KEY_CACHE[key_id] = k
return k
# Fetch it
from app import activitypub as ap from app import activitypub as ap
actor = await ap.fetch(key_id) actor = await ap.fetch(key_id)
@ -82,6 +103,7 @@ async def _get_public_key(key_id: str) -> Key:
f"failed to fetch requested key {key_id}: got {actor['publicKey']['id']}" f"failed to fetch requested key {key_id}: got {actor['publicKey']['id']}"
) )
_KEY_CACHE[key_id] = k
return k return k
@ -93,6 +115,7 @@ class HTTPSigInfo:
async def httpsig_checker( async def httpsig_checker(
request: fastapi.Request, request: fastapi.Request,
db_session: AsyncSession = fastapi.Depends(get_db_session),
) -> HTTPSigInfo: ) -> HTTPSigInfo:
body = await request.body() body = await request.body()
@ -111,7 +134,7 @@ async def httpsig_checker(
) )
try: try:
k = await _get_public_key(hsig["keyId"]) k = await _get_public_key(db_session, hsig["keyId"])
except ap.ObjectIsGoneError: except ap.ObjectIsGoneError:
logger.info("Actor is gone") logger.info("Actor is gone")
return HTTPSigInfo(has_valid_signature=False) return HTTPSigInfo(has_valid_signature=False)

View file

@ -8,6 +8,7 @@ from typing import Any
from typing import Type from typing import Type
import httpx import httpx
from cachetools import LFUCache
from fastapi import Depends from fastapi import Depends
from fastapi import FastAPI from fastapi import FastAPI
from fastapi import Form from fastapi import Form
@ -56,6 +57,9 @@ from app.utils import pagination
from app.utils.emoji import EMOJIS_BY_NAME from app.utils.emoji import EMOJIS_BY_NAME
from app.webfinger import get_remote_follow_template from app.webfinger import get_remote_follow_template
_RESIZED_CACHE = LFUCache(32)
# TODO(ts): # TODO(ts):
# #
# Next: # Next:
@ -728,7 +732,7 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp
@app.get("/proxy/media/{encoded_url}/{size}") @app.get("/proxy/media/{encoded_url}/{size}")
def serve_proxy_media_resized( async def serve_proxy_media_resized(
request: Request, request: Request,
encoded_url: str, encoded_url: str,
size: int, size: int,
@ -738,18 +742,38 @@ def serve_proxy_media_resized(
# Decode the base64-encoded URL # Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode() url = base64.urlsafe_b64decode(encoded_url).decode()
is_cached = False
is_resized = False
if cached_resp := _RESIZED_CACHE.get((url, size)):
is_resized, resized_content, resized_mimetype, resp_headers = cached_resp
if is_resized:
return PlainTextResponse(
resized_content,
media_type=resized_mimetype,
headers=resp_headers,
)
is_cached = True
# Request the URL (and filter request headers) # Request the URL (and filter request headers)
proxy_resp = httpx.get( async with httpx.AsyncClient() as client:
proxy_resp = await client.get(
url, url,
headers=[ headers=[
(k, v) (k, v)
for (k, v) in request.headers.raw for (k, v) in request.headers.raw
if k.lower() if k.lower()
not in [b"host", b"cookie", b"x-forwarded-for", b"x-real-ip", b"user-agent"] not in [
b"host",
b"cookie",
b"x-forwarded-for",
b"x-real-ip",
b"user-agent",
]
] ]
+ [(b"user-agent", USER_AGENT.encode())], + [(b"user-agent", USER_AGENT.encode())],
) )
if proxy_resp.status_code != 200: if proxy_resp.status_code != 200 or (is_cached and not is_resized):
return PlainTextResponse( return PlainTextResponse(
proxy_resp.content, proxy_resp.content,
status_code=proxy_resp.status_code, status_code=proxy_resp.status_code,
@ -772,15 +796,23 @@ def serve_proxy_media_resized(
try: try:
out = BytesIO(proxy_resp.content) out = BytesIO(proxy_resp.content)
i = Image.open(out) i = Image.open(out)
if i.is_animated: if getattr(i, "is_animated", False):
raise ValueError raise ValueError
i.thumbnail((size, size)) i.thumbnail((size, size))
resized_buf = BytesIO() resized_buf = BytesIO()
i.save(resized_buf, format=i.format) i.save(resized_buf, format=i.format)
resized_buf.seek(0) resized_buf.seek(0)
resized_content = resized_buf.read()
resized_mimetype = i.get_format_mimetype() # type: ignore
_RESIZED_CACHE[(url, size)] = (
True,
resized_content,
resized_mimetype,
proxy_resp_headers,
)
return PlainTextResponse( return PlainTextResponse(
resized_buf.read(), resized_content,
media_type=i.get_format_mimetype(), # type: ignore media_type=resized_mimetype,
headers=proxy_resp_headers, headers=proxy_resp_headers,
) )
except ValueError: except ValueError:

View file

@ -190,7 +190,8 @@ def _clean_html(html: str, note: Object) -> str:
strip=True, strip=True,
), ),
note, note,
) ),
is_local=note.ap_id.startswith(BASE_URL),
) )
except Exception: except Exception:
raise raise
@ -241,12 +242,15 @@ def _html2text(content: str) -> str:
return H2T.handle(content) return H2T.handle(content)
def _replace_emoji(u, data): def _replace_emoji(u: str, _) -> str:
filename = hex(ord(u))[2:] filename = hex(ord(u))[2:]
return config.EMOJI_TPL.format(filename=filename, raw=u) return config.EMOJI_TPL.format(filename=filename, raw=u)
def _emojify(text: str): def _emojify(text: str, is_local: bool) -> str:
if not is_local:
return text
return emoji.replace_emoji( return emoji.replace_emoji(
text, text,
replace=_replace_emoji, replace=_replace_emoji,

View file

@ -16,7 +16,10 @@
</div> </div>
{{ utils.display_actor(inbox_object.actor, actors_metadata) }} {{ utils.display_actor(inbox_object.actor, actors_metadata) }}
{% else %} {% else %}
<p>
Implement {{ inbox_object.ap_type }} Implement {{ inbox_object.ap_type }}
{{ inbox_object.ap_object }}
</p>
{% endif %} {% endif %}
{% endfor %} {% endfor %}

2
poetry.lock generated
View file

@ -1143,7 +1143,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "19151bbc858317aec5747a8f45a86b47cc198111422cc166a94634ad1941d8bc" content-hash = "91e35a13d21bb5fd3e8916aee95c0a8019bec3cf4f0c677bb86641f1d88dcfe3"
[metadata.files] [metadata.files]
aiosqlite = [ aiosqlite = [

View file

@ -40,6 +40,7 @@ emoji = "^1.7.0"
PyLD = "^2.0.3" PyLD = "^2.0.3"
aiosqlite = "^0.17.0" aiosqlite = "^0.17.0"
sqlalchemy2-stubs = "^0.0.2-alpha.24" sqlalchemy2-stubs = "^0.0.2-alpha.24"
cachetools = "^5.2.0"
[tool.poetry.dev-dependencies] [tool.poetry.dev-dependencies]
black = "^22.3.0" black = "^22.3.0"