diff --git a/litestar/middleware/_internal/exceptions/middleware.py b/litestar/middleware/_internal/exceptions/middleware.py index 3433a979b4..c11c28543e 100644 --- a/litestar/middleware/_internal/exceptions/middleware.py +++ b/litestar/middleware/_internal/exceptions/middleware.py @@ -23,6 +23,7 @@ from litestar import Response from litestar.app import Litestar from litestar.connection import Request + from litestar.handlers import BaseRouteHandler from litestar.logging import BaseLoggingConfig from litestar.types import ( ASGIApp, @@ -202,7 +203,11 @@ async def handle_request_exception( exception_handler = get_exception_handler(exception_handlers, exc) or self.default_http_exception_handler request: Request[Any, Any, Any] = litestar_app.request_class(scope=scope, receive=receive, send=send) response = exception_handler(request, exc) - await response.to_asgi_response(app=None, request=request)(scope=scope, receive=receive, send=send) + route_handler: BaseRouteHandler | None = scope.get("route_handler") + type_encoders = route_handler.resolve_type_encoders() if route_handler else litestar_app.type_encoders + await response.to_asgi_response(app=None, request=request, type_encoders=type_encoders)( + scope=scope, receive=receive, send=send + ) @staticmethod async def handle_websocket_exception(send: Send, exc: Exception) -> None: diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index a1540d353c..81b402015e 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Any, Callable, Generator, Optional from unittest.mock import MagicMock +import pydantic import pytest from pytest_mock import MockerFixture from starlette.exceptions import HTTPException as StarletteHTTPException @@ -146,6 +147,46 @@ def exception_handler(request: Request, exc: Exception) -> Response: } +def test_exception_handler_middleware_handler_response_type_encoding( + scope: HTTPScope, middleware: ExceptionHandlerMiddleware +) -> None: + class ErrorMessage(pydantic.BaseModel): + message: str + + @get("/") + def handler(_: Request) -> None: + raise Exception + + def exception_handler(_: Request, _e: Exception) -> Response: + return Response(content=ErrorMessage(message="the error message"), status_code=HTTP_500_INTERNAL_SERVER_ERROR) + + app = Litestar(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None) + + with TestClient(app) as client: + response = client.get("/") + assert response.json() == {"message": "the error message"} + + +def test_exception_handler_middleware_handler_response_type_encoding_no_route_handler( + scope: HTTPScope, middleware: ExceptionHandlerMiddleware +) -> None: + class ErrorMessage(pydantic.BaseModel): + message: str + + @get("/") + def handler(_: Request) -> None: + raise Exception + + def exception_handler(_: Request, _e: Exception) -> Response: + return Response(content=ErrorMessage(message="the error message"), status_code=HTTP_500_INTERNAL_SERVER_ERROR) + + app = Litestar(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None) + + with TestClient(app) as client: + response = client.get("/not-found") + assert response.json() == {"message": "the error message"} + + def test_exception_handler_middleware_calls_app_level_after_exception_hook() -> None: @get("/test") def handler() -> None: