From d245201851ad4385a696b4162ba74e9112e2cc14 Mon Sep 17 00:00:00 2001 From: Thomas Sileo Date: Thu, 14 Jul 2022 15:16:45 +0200 Subject: [PATCH] Tweak middleware --- app/actor.py | 3 +-- app/main.py | 43 ++++++++++--------------------------------- 2 files changed, 11 insertions(+), 35 deletions(-) diff --git a/app/actor.py b/app/actor.py index 7cc85eb..eb60c82 100644 --- a/app/actor.py +++ b/app/actor.py @@ -145,8 +145,7 @@ async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "Actor handle=_handle(ap_actor), ) db_session.add(actor) - await db_session.commit() - await db_session.refresh(actor) + await db_session.flush() return actor diff --git a/app/main.py b/app/main.py index 476b5db..ed684d8 100644 --- a/app/main.py +++ b/app/main.py @@ -92,54 +92,29 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac class CustomMiddleware: def __init__( self, - app: "ASGI3Application", + app: ASGI3Application, ) -> None: self.app = app async def __call__( self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) -> None: - """ - if scope["type"] in ("http", "websocket"): - scope = cast(HTTPScope | WebSocketScope, scope) - client_addr: tuple[str, int] | None = scope.get("client") - client_host = client_addr[0] if client_addr else None - - if self.always_trust or client_host in self.trusted_hosts: - headers = dict(scope["headers"]) - - if b"x-forwarded-proto" in headers: - # Determine if the incoming request was http or https based on - # the X-Forwarded-Proto header. - x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1") - scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index] - - if b"x-forwarded-for" in headers: - # Determine the client address from the last trusted IP in the - # X-Forwarded-For header. We've lost the connecting client's port - # information by now, so only include the host. - x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1") - x_forwarded_for_hosts = [ - item.strip() for item in x_forwarded_for.split(",") - ] - host = self.get_trusted_client_host(x_forwarded_for_hosts) - port = 0 - scope["client"] = (host, port) # type: ignore[arg-type] - """ - + # We only care about HTTP requests if scope["type"] != "http": await self.app(scope, receive, send) return - instance = {"http_status_code": None} - + response_details = {} start_time = time.perf_counter() request_id = os.urandom(8).hex() async def send_wrapper(message: Message) -> None: if message["type"] == "http.response.start": - instance["http_status_code"] = message["status"] + # Extract the HTTP response status code + response_details["status_code"] = message["status"] + + # And add the security headers headers = MutableHeaders(scope=message) headers["X-Request-ID"] = request_id headers["Server"] = "microblogpub" @@ -160,6 +135,8 @@ class CustomMiddleware: await send(message) # type: ignore + # Make loguru ouput the request ID on every log statement within + # the request with logger.contextualize(request_id=request_id): client_host, client_port = scope["client"] # type: ignore scheme = scope["scheme"] @@ -175,7 +152,7 @@ class CustomMiddleware: finally: elapsed_time = time.perf_counter() - start_time logger.info( - f"status_code={instance['http_status_code']} " + f"status_code={response_details['status_code']} " f"{elapsed_time=:.2f}s" )