mirror of
https://git.sr.ht/~tsileo/microblog.pub
synced 2024-12-22 13:14:28 +00:00
Tweak middleware
This commit is contained in:
parent
a39f874ad5
commit
d245201851
2 changed files with 11 additions and 35 deletions
|
@ -145,8 +145,7 @@ async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "Actor
|
||||||
handle=_handle(ap_actor),
|
handle=_handle(ap_actor),
|
||||||
)
|
)
|
||||||
db_session.add(actor)
|
db_session.add(actor)
|
||||||
await db_session.commit()
|
await db_session.flush()
|
||||||
await db_session.refresh(actor)
|
|
||||||
return actor
|
return actor
|
||||||
|
|
||||||
|
|
||||||
|
|
43
app/main.py
43
app/main.py
|
@ -92,54 +92,29 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
|
||||||
class CustomMiddleware:
|
class CustomMiddleware:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app: "ASGI3Application",
|
app: ASGI3Application,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.app = app
|
self.app = app
|
||||||
|
|
||||||
async def __call__(
|
async def __call__(
|
||||||
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
# We only care about HTTP requests
|
||||||
if scope["type"] in ("http", "websocket"):
|
|
||||||
scope = cast(HTTPScope | WebSocketScope, scope)
|
|
||||||
client_addr: tuple[str, int] | None = scope.get("client")
|
|
||||||
client_host = client_addr[0] if client_addr else None
|
|
||||||
|
|
||||||
if self.always_trust or client_host in self.trusted_hosts:
|
|
||||||
headers = dict(scope["headers"])
|
|
||||||
|
|
||||||
if b"x-forwarded-proto" in headers:
|
|
||||||
# Determine if the incoming request was http or https based on
|
|
||||||
# the X-Forwarded-Proto header.
|
|
||||||
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
|
|
||||||
scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index]
|
|
||||||
|
|
||||||
if b"x-forwarded-for" in headers:
|
|
||||||
# Determine the client address from the last trusted IP in the
|
|
||||||
# X-Forwarded-For header. We've lost the connecting client's port
|
|
||||||
# information by now, so only include the host.
|
|
||||||
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
|
||||||
x_forwarded_for_hosts = [
|
|
||||||
item.strip() for item in x_forwarded_for.split(",")
|
|
||||||
]
|
|
||||||
host = self.get_trusted_client_host(x_forwarded_for_hosts)
|
|
||||||
port = 0
|
|
||||||
scope["client"] = (host, port) # type: ignore[arg-type]
|
|
||||||
"""
|
|
||||||
|
|
||||||
if scope["type"] != "http":
|
if scope["type"] != "http":
|
||||||
await self.app(scope, receive, send)
|
await self.app(scope, receive, send)
|
||||||
return
|
return
|
||||||
|
|
||||||
instance = {"http_status_code": None}
|
response_details = {}
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
request_id = os.urandom(8).hex()
|
request_id = os.urandom(8).hex()
|
||||||
|
|
||||||
async def send_wrapper(message: Message) -> None:
|
async def send_wrapper(message: Message) -> None:
|
||||||
if message["type"] == "http.response.start":
|
if message["type"] == "http.response.start":
|
||||||
instance["http_status_code"] = message["status"]
|
|
||||||
|
|
||||||
|
# Extract the HTTP response status code
|
||||||
|
response_details["status_code"] = message["status"]
|
||||||
|
|
||||||
|
# And add the security headers
|
||||||
headers = MutableHeaders(scope=message)
|
headers = MutableHeaders(scope=message)
|
||||||
headers["X-Request-ID"] = request_id
|
headers["X-Request-ID"] = request_id
|
||||||
headers["Server"] = "microblogpub"
|
headers["Server"] = "microblogpub"
|
||||||
|
@ -160,6 +135,8 @@ class CustomMiddleware:
|
||||||
|
|
||||||
await send(message) # type: ignore
|
await send(message) # type: ignore
|
||||||
|
|
||||||
|
# Make loguru ouput the request ID on every log statement within
|
||||||
|
# the request
|
||||||
with logger.contextualize(request_id=request_id):
|
with logger.contextualize(request_id=request_id):
|
||||||
client_host, client_port = scope["client"] # type: ignore
|
client_host, client_port = scope["client"] # type: ignore
|
||||||
scheme = scope["scheme"]
|
scheme = scope["scheme"]
|
||||||
|
@ -175,7 +152,7 @@ class CustomMiddleware:
|
||||||
finally:
|
finally:
|
||||||
elapsed_time = time.perf_counter() - start_time
|
elapsed_time = time.perf_counter() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"status_code={instance['http_status_code']} "
|
f"status_code={response_details['status_code']} "
|
||||||
f"{elapsed_time=:.2f}s"
|
f"{elapsed_time=:.2f}s"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue