diff --git a/requirements.txt b/requirements.txt index 6d044c13d..65e240832 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,8 @@ types-contextvars==2.4.7.2 types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.4.0 -trio==0.21.0 +trio==0.22.1 +anyio@git+/~https://github.com/agronholm/anyio.git # Documentation mkdocs==1.4.3 diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 170a805a7..ee99ee6cb 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,12 +1,18 @@ +import sys import typing +from contextlib import contextmanager import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request from starlette.responses import ContentStream, Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] @@ -14,6 +20,17 @@ T = typing.TypeVar("T") +@contextmanager +def _convert_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc + + class _CachedRequest(Request): """ If the user calls Request.body() from their dispatch function @@ -107,6 +124,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None + send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send_stream, recv_stream = anyio.create_memory_object_stream() async def receive_or_disconnect() -> Message: @@ -182,10 +201,11 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - async with anyio.create_task_group() as task_group: - response = await self.dispatch_func(request, call_next) - await response(scope, wrapped_receive, send) - response_sent.set() + with _convert_excgroups(): + async with anyio.create_task_group() as task_group: + response = await self.dispatch_func(request, call_next) + await response(scope, wrapped_receive, send) + response_sent.set() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 9dbd06528..d4a117cac 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -5,6 +5,7 @@ import warnings import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette.types import Receive, Scope, Send @@ -72,6 +73,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + def __init__(self, app: typing.Callable, scope: Scope) -> None: self.app = app self.scope = scope diff --git a/starlette/testclient.py b/starlette/testclient.py index a91ad7bfc..c9ae97a08 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,6 +12,7 @@ import anyio import anyio.from_thread +from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable @@ -737,12 +738,18 @@ def __enter__(self) -> "TestClient": def reset_portal() -> None: self.portal = None - self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) - self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) - ) + send1: ObjectSendStream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ] + receive1: ObjectReceiveStream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ] + send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + send1, receive1 = anyio.create_memory_object_stream(math.inf) + send2, receive2 = anyio.create_memory_object_stream(math.inf) + self.stream_send = StapledObjectStream(send1, receive1) + self.stream_receive = StapledObjectStream(send2, receive2) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index bcb4cd6ff..ad3975403 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,6 +4,9 @@ from starlette.middleware.wsgi import WSGIMiddleware, build_environ +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + def hello_world(environ, start_response): status = "200 OK" @@ -66,9 +69,12 @@ def test_wsgi_exception(test_client_factory): # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(ExceptionGroup) as exc: client.get("/") + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], RuntimeError) + def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. diff --git a/tests/test_websockets.py b/tests/test_websockets.py index c1ec1153e..71bccd455 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,7 +1,9 @@ import sys +from typing import Any, MutableMapping import anyio import pytest +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette import status from starlette.types import Receive, Scope, Send @@ -178,6 +180,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_websocket_concurrency_pattern(test_client_factory): + stream_send: ObjectSendStream[MutableMapping[str, Any]] + stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] stream_send, stream_receive = anyio.create_memory_object_stream() async def reader(websocket):