From 5c4169960d8248a43c9fbb895d7fd6d0d4e8d87a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 12 Jul 2023 19:35:32 +0300 Subject: [PATCH 01/22] Upgraded to AnyIO 4.0 and dropped Python 3.8 support --- requirements.txt | 3 ++- tests/middleware/test_base.py | 20 +++++++++++++++----- tests/middleware/test_wsgi.py | 5 ++++- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 88cd2a159..bdfc2588e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,7 +12,8 @@ types-contextvars==2.4.7.2 types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.3.1 -trio==0.21.0 +trio==0.22.1 +anyio@git+/~https://github.com/agronholm/anyio.git@detect-asyncio-native-cancel # Documentation mkdocs==1.4.3 diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index cf4780cce..c8e72291d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -74,16 +74,26 @@ def test_custom_middleware(test_client_factory): response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception) as ctx: + with pytest.raises(ExceptionGroup) as ctx: response = client.get("/exc") - assert str(ctx.value) == "Exc" + assert len(ctx.value.exceptions) == 1 + assert str(ctx.value.exceptions[0]) == "Exc" - with pytest.raises(Exception) as ctx: + with pytest.raises(ExceptionGroup) as ctx: response = client.get("/exc-stream") - assert str(ctx.value) == "Faulty Stream" + exc = ctx.value + while isinstance(exc, ExceptionGroup): + assert len(exc.exceptions) == 1 + exc = exc.exceptions[0] + assert str(exc) == "Faulty Stream" - with pytest.raises(RuntimeError): + with pytest.raises(ExceptionGroup) as ctx: response = client.get("/no-response") + exc = ctx.value + while isinstance(exc, ExceptionGroup): + assert len(exc.exceptions) == 1 + exc = exc.exceptions[0] + assert isinstance(exc, RuntimeError) with client.websocket_connect("/ws") as session: text = session.receive_text() diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index bcb4cd6ff..c05df3f7d 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -66,9 +66,12 @@ def test_wsgi_exception(test_client_factory): # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError): + with pytest.raises(ExceptionGroup) as exc: client.get("/") + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], RuntimeError) + def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. From 026ced367f05d01c53a8b3221fe272be44b31299 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 12 Jul 2023 19:52:29 +0300 Subject: [PATCH 02/22] Added conditional imports for ExceptionGroup --- tests/middleware/test_base.py | 4 ++++ tests/middleware/test_wsgi.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index c8e72291d..9284cdb8d 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,4 +1,5 @@ import contextvars +import sys from contextlib import AsyncExitStack from typing import AsyncGenerator, Awaitable, Callable, List, Union @@ -15,6 +16,9 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index c05df3f7d..6a0257dea 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,6 +4,9 @@ from starlette.middleware.wsgi import WSGIMiddleware, build_environ +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + def hello_world(environ, start_response): status = "200 OK" From 67e5a4116866818879e1287583178a13584d0275 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 12 Jul 2023 21:45:35 +0300 Subject: [PATCH 03/22] Fixed mypy errors --- starlette/middleware/base.py | 3 ++- starlette/middleware/wsgi.py | 3 ++- starlette/testclient.py | 6 ++++-- tests/middleware/test_base.py | 2 +- tests/test_websockets.py | 3 ++- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 170a805a7..ef5a5d4d6 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -107,7 +107,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None - send_stream, recv_stream = anyio.create_memory_object_stream() + send_stream, recv_stream = anyio.create_memory_object_stream[ + typing.MutableMapping[str, typing.Any]]() async def receive_or_disconnect() -> Message: if response_sent.is_set(): diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 9dbd06528..896c3c450 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -77,7 +77,8 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.scope = scope self.status = None self.response_headers = None - self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + self.stream_send, self.stream_receive = anyio.create_memory_object_stream[ + typing.MutableMapping[str, typing.Any]]( math.inf ) self.response_started = False diff --git a/starlette/testclient.py b/starlette/testclient.py index 1b4f1303f..5eff97c90 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -738,10 +738,12 @@ def reset_portal() -> None: self.portal = None self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) + *anyio.create_memory_object_stream[ + typing.Optional[typing.MutableMapping[str, typing.Any]]](math.inf) ) self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream(math.inf) + *anyio.create_memory_object_stream[ + typing.MutableMapping[str, typing.Any]](math.inf) ) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 9284cdb8d..238a9fddc 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -85,7 +85,7 @@ def test_custom_middleware(test_client_factory): with pytest.raises(ExceptionGroup) as ctx: response = client.get("/exc-stream") - exc = ctx.value + exc: Exception = ctx.value while isinstance(exc, ExceptionGroup): assert len(exc.exceptions) == 1 exc = exc.exceptions[0] diff --git a/tests/test_websockets.py b/tests/test_websockets.py index c1ec1153e..ca93a6ce5 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -1,4 +1,5 @@ import sys +from typing import Any, MutableMapping import anyio import pytest @@ -178,7 +179,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_websocket_concurrency_pattern(test_client_factory): - stream_send, stream_receive = anyio.create_memory_object_stream() + stream_send, stream_receive = anyio.create_memory_object_stream[MutableMapping[str, Any]]() async def reader(websocket): async with stream_send: From df70155ae94ac79c1baa54e52599a3855120027c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 12 Jul 2023 21:47:41 +0300 Subject: [PATCH 04/22] Fixed black errors --- starlette/middleware/base.py | 3 ++- starlette/middleware/wsgi.py | 5 ++--- starlette/testclient.py | 6 ++++-- tests/test_websockets.py | 4 +++- 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ef5a5d4d6..9c12ddec1 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -108,7 +108,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None send_stream, recv_stream = anyio.create_memory_object_stream[ - typing.MutableMapping[str, typing.Any]]() + typing.MutableMapping[str, typing.Any] + ]() async def receive_or_disconnect() -> Message: if response_sent.is_set(): diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 896c3c450..1263a3f6f 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -78,9 +78,8 @@ def __init__(self, app: typing.Callable, scope: Scope) -> None: self.status = None self.response_headers = None self.stream_send, self.stream_receive = anyio.create_memory_object_stream[ - typing.MutableMapping[str, typing.Any]]( - math.inf - ) + typing.MutableMapping[str, typing.Any] + ](math.inf) self.response_started = False self.exc_info: typing.Any = None diff --git a/starlette/testclient.py b/starlette/testclient.py index 5eff97c90..8f3980687 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -739,11 +739,13 @@ def reset_portal() -> None: self.stream_send = StapledObjectStream( *anyio.create_memory_object_stream[ - typing.Optional[typing.MutableMapping[str, typing.Any]]](math.inf) + typing.Optional[typing.MutableMapping[str, typing.Any]] + ](math.inf) ) self.stream_receive = StapledObjectStream( *anyio.create_memory_object_stream[ - typing.MutableMapping[str, typing.Any]](math.inf) + typing.MutableMapping[str, typing.Any] + ](math.inf) ) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index ca93a6ce5..01160b6ee 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -179,7 +179,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_websocket_concurrency_pattern(test_client_factory): - stream_send, stream_receive = anyio.create_memory_object_stream[MutableMapping[str, Any]]() + stream_send, stream_receive = anyio.create_memory_object_stream[ + MutableMapping[str, Any] + ]() async def reader(websocket): async with stream_send: From e42480f67eec414be2f95c87e787254d8dbe9ddf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 12:54:43 +0300 Subject: [PATCH 05/22] Fixed last failure --- tests/middleware/test_base.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 238a9fddc..e81dcb628 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -1,10 +1,10 @@ import contextvars -import sys from contextlib import AsyncExitStack from typing import AsyncGenerator, Awaitable, Callable, List, Union import anyio import pytest +from exceptiongroup import ExceptionGroup, catch from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -16,9 +16,6 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send -if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup - class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -219,12 +216,23 @@ async def homepage(request): ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") + def handle_assertion_error(eg): + if middleware_cls is CustomMiddlewareUsingBaseHTTPMiddleware: + pytest.xfail( + "BaseHTTPMiddleware creates a TaskGroup which copies the context" + "and erases any changes to it made within the TaskGroup" + ) + + raise + app = Starlette( middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] ) client = test_client_factory(app) - response = client.get("/") + with catch({AssertionError: handle_assertion_error}): + response = client.get("/") + assert response.status_code == 200, response.content From b9f27e4fc6d30e1514ce5df4dce12157211bc122 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 12:57:13 +0300 Subject: [PATCH 06/22] Added missing test dependency --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index bdfc2588e..2f04fdb21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.3.1 trio==0.22.1 +exceptiongroup==1.1.2 anyio@git+/~https://github.com/agronholm/anyio.git@detect-asyncio-native-cancel # Documentation From 9568e8e7921668a615f965873b8d4328f59fc6cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 13:00:47 +0300 Subject: [PATCH 07/22] Exclude "raise" in handle_assertion_error() from coverage --- tests/middleware/test_base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index e81dcb628..3af34c333 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -223,7 +223,7 @@ def handle_assertion_error(eg): "and erases any changes to it made within the TaskGroup" ) - raise + raise # pragma: no cover app = Starlette( middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] From aa3874c12346c50304608ddad2b36100ef276959 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 13:03:36 +0300 Subject: [PATCH 08/22] Exclude conditional import from coverage --- tests/middleware/test_wsgi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 6a0257dea..ad3975403 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,7 +4,7 @@ from starlette.middleware.wsgi import WSGIMiddleware, build_environ -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import ExceptionGroup From 6c71321593c4195cb71ba5ab99448401649954b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 17:20:40 +0300 Subject: [PATCH 09/22] Simplified the changes --- tests/exc_converter.py | 19 +++++++++++++++++++ tests/middleware/test_base.py | 33 +++++++-------------------------- tests/middleware/test_wsgi.py | 9 ++------- 3 files changed, 28 insertions(+), 33 deletions(-) create mode 100644 tests/exc_converter.py diff --git a/tests/exc_converter.py b/tests/exc_converter.py new file mode 100644 index 000000000..5cd8490d7 --- /dev/null +++ b/tests/exc_converter.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +import sys +from collections.abc import Generator +from contextlib import contextmanager + +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + + +@contextmanager +def convert_excgroups() -> Generator[None, None, None]: + try: + yield + except BaseException as exc: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 3af34c333..a95bac1e6 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -4,7 +4,6 @@ import anyio import pytest -from exceptiongroup import ExceptionGroup, catch from starlette.applications import Starlette from starlette.background import BackgroundTask @@ -15,6 +14,7 @@ from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send +from ..exc_converter import convert_excgroups class CustomMiddleware(BaseHTTPMiddleware): @@ -75,26 +75,16 @@ def test_custom_middleware(test_client_factory): response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(ExceptionGroup) as ctx: + with pytest.raises(Exception) as ctx, convert_excgroups(): response = client.get("/exc") - assert len(ctx.value.exceptions) == 1 - assert str(ctx.value.exceptions[0]) == "Exc" + assert str(ctx.value) == "Exc" - with pytest.raises(ExceptionGroup) as ctx: + with pytest.raises(Exception) as ctx, convert_excgroups(): response = client.get("/exc-stream") - exc: Exception = ctx.value - while isinstance(exc, ExceptionGroup): - assert len(exc.exceptions) == 1 - exc = exc.exceptions[0] - assert str(exc) == "Faulty Stream" + assert str(ctx.value) == "Faulty Stream" - with pytest.raises(ExceptionGroup) as ctx: + with pytest.raises(RuntimeError), convert_excgroups(): response = client.get("/no-response") - exc = ctx.value - while isinstance(exc, ExceptionGroup): - assert len(exc.exceptions) == 1 - exc = exc.exceptions[0] - assert isinstance(exc, RuntimeError) with client.websocket_connect("/ws") as session: text = session.receive_text() @@ -216,21 +206,12 @@ async def homepage(request): ctxvar.set("set by endpoint") return PlainTextResponse("Homepage") - def handle_assertion_error(eg): - if middleware_cls is CustomMiddlewareUsingBaseHTTPMiddleware: - pytest.xfail( - "BaseHTTPMiddleware creates a TaskGroup which copies the context" - "and erases any changes to it made within the TaskGroup" - ) - - raise # pragma: no cover - app = Starlette( middleware=[Middleware(middleware_cls)], routes=[Route("/", homepage)] ) client = test_client_factory(app) - with catch({AssertionError: handle_assertion_error}): + with convert_excgroups(): response = client.get("/") assert response.status_code == 200, response.content diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index ad3975403..3741d8757 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -3,9 +3,7 @@ import pytest from starlette.middleware.wsgi import WSGIMiddleware, build_environ - -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import ExceptionGroup +from ..exc_converter import convert_excgroups def hello_world(environ, start_response): @@ -69,12 +67,9 @@ def test_wsgi_exception(test_client_factory): # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(ExceptionGroup) as exc: + with pytest.raises(RuntimeError), convert_excgroups(): client.get("/") - assert len(exc.value.exceptions) == 1 - assert isinstance(exc.value.exceptions[0], RuntimeError) - def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. From 7225f96f004dff5f2bae2a15d5e8d005ad2864b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 17:28:28 +0300 Subject: [PATCH 10/22] Fixed linting errors --- tests/middleware/test_base.py | 1 + tests/middleware/test_wsgi.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index a95bac1e6..29276b43b 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -14,6 +14,7 @@ from starlette.routing import Route, WebSocketRoute from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send + from ..exc_converter import convert_excgroups diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index 3741d8757..a59925891 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -3,6 +3,7 @@ import pytest from starlette.middleware.wsgi import WSGIMiddleware, build_environ + from ..exc_converter import convert_excgroups From 1e9096d99cdaa1ac400a6676d2b53f577527eb9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 18:21:49 +0300 Subject: [PATCH 11/22] Dropped exceptiongroup as a test dependency --- requirements.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2f04fdb21..bdfc2588e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.3.1 trio==0.22.1 -exceptiongroup==1.1.2 anyio@git+/~https://github.com/agronholm/anyio.git@detect-asyncio-native-cancel # Documentation From e1becdb770ce4371296070aaade840c279ab7ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 18:45:44 +0300 Subject: [PATCH 12/22] Restored compatibility with AnyIO 3.x --- starlette/middleware/base.py | 7 +++++-- starlette/middleware/wsgi.py | 10 +++++++--- starlette/testclient.py | 17 +++++++---------- tests/test_websockets.py | 7 ++++--- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 9c12ddec1..cbbd1eae8 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,6 +1,7 @@ import typing import anyio +from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request @@ -107,9 +108,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None - send_stream, recv_stream = anyio.create_memory_object_stream[ + send_stream: MemoryObjectSendStream[typing.MutableMapping[str, typing.Any]] + recv_stream: MemoryObjectReceiveStream[ typing.MutableMapping[str, typing.Any] - ]() + ] + send_stream, recv_stream = anyio.create_memory_object_stream() async def receive_or_disconnect() -> Message: if response_sent.is_set(): diff --git a/starlette/middleware/wsgi.py b/starlette/middleware/wsgi.py index 1263a3f6f..d4a117cac 100644 --- a/starlette/middleware/wsgi.py +++ b/starlette/middleware/wsgi.py @@ -5,6 +5,7 @@ import warnings import anyio +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette.types import Receive, Scope, Send @@ -72,14 +73,17 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: class WSGIResponder: + stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + def __init__(self, app: typing.Callable, scope: Scope) -> None: self.app = app self.scope = scope self.status = None self.response_headers = None - self.stream_send, self.stream_receive = anyio.create_memory_object_stream[ - typing.MutableMapping[str, typing.Any] - ](math.inf) + self.stream_send, self.stream_receive = anyio.create_memory_object_stream( + math.inf + ) self.response_started = False self.exc_info: typing.Any = None diff --git a/starlette/testclient.py b/starlette/testclient.py index 8f3980687..7da1f6984 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,6 +12,7 @@ import anyio import anyio.from_thread +from anyio.abc import ObjectStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable @@ -737,16 +738,12 @@ def __enter__(self) -> "TestClient": def reset_portal() -> None: self.portal = None - self.stream_send = StapledObjectStream( - *anyio.create_memory_object_stream[ - typing.Optional[typing.MutableMapping[str, typing.Any]] - ](math.inf) - ) - self.stream_receive = StapledObjectStream( - *anyio.create_memory_object_stream[ - typing.MutableMapping[str, typing.Any] - ](math.inf) - ) + self.stream_send: ObjectStream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ] = StapledObjectStream(*anyio.create_memory_object_stream(math.inf)) + self.stream_receive: ObjectStream[ + typing.MutableMapping[str, typing.Any] + ] = StapledObjectStream(*anyio.create_memory_object_stream(math.inf)) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 01160b6ee..71bccd455 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -3,6 +3,7 @@ import anyio import pytest +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette import status from starlette.types import Receive, Scope, Send @@ -179,9 +180,9 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: def test_websocket_concurrency_pattern(test_client_factory): - stream_send, stream_receive = anyio.create_memory_object_stream[ - MutableMapping[str, Any] - ]() + stream_send: ObjectSendStream[MutableMapping[str, Any]] + stream_receive: ObjectReceiveStream[MutableMapping[str, Any]] + stream_send, stream_receive = anyio.create_memory_object_stream() async def reader(websocket): async with stream_send: From ac67df8e8116b54d99be9c9ddcd7ecab3a6abbf5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 18:53:32 +0300 Subject: [PATCH 13/22] Fixed mypy errors --- starlette/testclient.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 7da1f6984..57fbdeeae 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,7 +12,7 @@ import anyio import anyio.from_thread -from anyio.abc import ObjectStream +from anyio.abc import ObjectStream, ObjectSendStream, ObjectReceiveStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable @@ -738,12 +738,14 @@ def __enter__(self) -> "TestClient": def reset_portal() -> None: self.portal = None - self.stream_send: ObjectStream[ - typing.Optional[typing.MutableMapping[str, typing.Any]] - ] = StapledObjectStream(*anyio.create_memory_object_stream(math.inf)) - self.stream_receive: ObjectStream[ - typing.MutableMapping[str, typing.Any] - ] = StapledObjectStream(*anyio.create_memory_object_stream(math.inf)) + send1: ObjectSendStream[typing.Optional[typing.MutableMapping[str, typing.Any]]] + receive1: ObjectReceiveStream[typing.Optional[typing.MutableMapping[str, typing.Any]]] + send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] + send1, receive1 = anyio.create_memory_object_stream(math.inf) + send2, receive2 = anyio.create_memory_object_stream(math.inf) + self.stream_send = StapledObjectStream(send1, receive1) + self.stream_receive = StapledObjectStream(send2, receive2) self.task = portal.start_task_soon(self.lifespan) portal.call(self.wait_startup) From 86fdcbf54f8f833c4e903df742c96bf154a6cd55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Thu, 13 Jul 2023 18:54:55 +0300 Subject: [PATCH 14/22] Fixed linting errors --- starlette/testclient.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 57fbdeeae..9a7043c8a 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -12,7 +12,7 @@ import anyio import anyio.from_thread -from anyio.abc import ObjectStream, ObjectSendStream, ObjectReceiveStream +from anyio.abc import ObjectReceiveStream, ObjectSendStream from anyio.streams.stapled import StapledObjectStream from starlette._utils import is_async_callable @@ -738,8 +738,12 @@ def __enter__(self) -> "TestClient": def reset_portal() -> None: self.portal = None - send1: ObjectSendStream[typing.Optional[typing.MutableMapping[str, typing.Any]]] - receive1: ObjectReceiveStream[typing.Optional[typing.MutableMapping[str, typing.Any]]] + send1: ObjectSendStream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ] + receive1: ObjectReceiveStream[ + typing.Optional[typing.MutableMapping[str, typing.Any]] + ] send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]] receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send1, receive1 = anyio.create_memory_object_stream(math.inf) From 9773d06ea2beb299a5a8c756518bd60104adf8aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Fri, 14 Jul 2023 01:29:40 +0300 Subject: [PATCH 15/22] Use general ABCs for object stream type annotations for consistency --- starlette/middleware/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index cbbd1eae8..3650b5ff4 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,7 +1,7 @@ import typing import anyio -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from anyio.abc import ObjectReceiveStream, ObjectSendStream from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request @@ -108,10 +108,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def call_next(request: Request) -> Response: app_exc: typing.Optional[Exception] = None - send_stream: MemoryObjectSendStream[typing.MutableMapping[str, typing.Any]] - recv_stream: MemoryObjectReceiveStream[ - typing.MutableMapping[str, typing.Any] - ] + send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]] + recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]] send_stream, recv_stream = anyio.create_memory_object_stream() async def receive_or_disconnect() -> Message: From 053b96ba74db365a86acbd3fffdbe88d5195ed0d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 14:37:05 +0200 Subject: [PATCH 16/22] Moved convert_excgroups() from tests to the main code base --- starlette/_exception_handler.py | 12 ++++++++++++ starlette/middleware/base.py | 10 ++++++---- tests/exc_converter.py | 19 ------------------- tests/middleware/test_base.py | 12 ++++-------- tests/middleware/test_wsgi.py | 3 +-- 5 files changed, 23 insertions(+), 33 deletions(-) delete mode 100644 tests/exc_converter.py diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 8a9beb3b2..365ea5fa2 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,4 +1,5 @@ import typing +from contextlib import contextmanager from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool @@ -74,3 +75,14 @@ async def sender(message: Message) -> None: await run_in_threadpool(handler, conn, exc) return wrapped_app + + +@contextmanager +def convert_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 3650b5ff4..c94823d89 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -3,6 +3,7 @@ import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from starlette._exception_handler import convert_excgroups from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request from starlette.responses import ContentStream, Response, StreamingResponse @@ -185,10 +186,11 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - async with anyio.create_task_group() as task_group: - response = await self.dispatch_func(request, call_next) - await response(scope, wrapped_receive, send) - response_sent.set() + with convert_excgroups(): + async with anyio.create_task_group() as task_group: + response = await self.dispatch_func(request, call_next) + await response(scope, wrapped_receive, send) + response_sent.set() async def dispatch( self, request: Request, call_next: RequestResponseEndpoint diff --git a/tests/exc_converter.py b/tests/exc_converter.py deleted file mode 100644 index 5cd8490d7..000000000 --- a/tests/exc_converter.py +++ /dev/null @@ -1,19 +0,0 @@ -from __future__ import annotations - -import sys -from collections.abc import Generator -from contextlib import contextmanager - -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup - - -@contextmanager -def convert_excgroups() -> Generator[None, None, None]: - try: - yield - except BaseException as exc: - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc diff --git a/tests/middleware/test_base.py b/tests/middleware/test_base.py index 29276b43b..cf4780cce 100644 --- a/tests/middleware/test_base.py +++ b/tests/middleware/test_base.py @@ -15,8 +15,6 @@ from starlette.testclient import TestClient from starlette.types import ASGIApp, Message, Receive, Scope, Send -from ..exc_converter import convert_excgroups - class CustomMiddleware(BaseHTTPMiddleware): async def dispatch(self, request, call_next): @@ -76,15 +74,15 @@ def test_custom_middleware(test_client_factory): response = client.get("/") assert response.headers["Custom-Header"] == "Example" - with pytest.raises(Exception) as ctx, convert_excgroups(): + with pytest.raises(Exception) as ctx: response = client.get("/exc") assert str(ctx.value) == "Exc" - with pytest.raises(Exception) as ctx, convert_excgroups(): + with pytest.raises(Exception) as ctx: response = client.get("/exc-stream") assert str(ctx.value) == "Faulty Stream" - with pytest.raises(RuntimeError), convert_excgroups(): + with pytest.raises(RuntimeError): response = client.get("/no-response") with client.websocket_connect("/ws") as session: @@ -212,9 +210,7 @@ async def homepage(request): ) client = test_client_factory(app) - with convert_excgroups(): - response = client.get("/") - + response = client.get("/") assert response.status_code == 200, response.content diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index a59925891..a0abdb6af 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -2,10 +2,9 @@ import pytest +from starlette._exception_handler import convert_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ -from ..exc_converter import convert_excgroups - def hello_world(environ, start_response): status = "200 OK" From 68cc1a83c3d25eb67d30c8ad72de255c5133ffd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 14:41:13 +0200 Subject: [PATCH 17/22] Updated anyio dependency to point to master --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 86af56c04..65e240832 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.4.0 trio==0.22.1 -anyio@git+/~https://github.com/agronholm/anyio.git@detect-asyncio-native-cancel +anyio@git+/~https://github.com/agronholm/anyio.git # Documentation mkdocs==1.4.3 From ef5f16a8d935489d8df7d6a43aba800caf5e1857 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 14:44:44 +0200 Subject: [PATCH 18/22] Fixed linting error --- starlette/_exception_handler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 365ea5fa2..0001e037f 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,3 +1,4 @@ +import sys import typing from contextlib import contextmanager @@ -9,6 +10,9 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + Handler = typing.Callable[..., typing.Any] ExceptionHandlers = typing.Dict[typing.Any, Handler] StatusHandlers = typing.Dict[int, Handler] From cb6008b3a5382a838bdaf917cbaf0a5639759264 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 14:50:41 +0200 Subject: [PATCH 19/22] Ignore coverage for a conditional import --- starlette/_exception_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index 0001e037f..fc1257c05 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -10,7 +10,7 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket -if sys.version_info < (3, 11): +if sys.version_info < (3, 11): # pragma: no cover from exceptiongroup import BaseExceptionGroup Handler = typing.Callable[..., typing.Any] From fcb90568bfde76d664e2d6c28a2402643148e7d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 14:59:48 +0200 Subject: [PATCH 20/22] Moved the exception group converter to middleware/base.py --- starlette/_exception_handler.py | 11 ----------- starlette/middleware/base.py | 15 +++++++++++++-- tests/middleware/test_wsgi.py | 6 ++++-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index fc1257c05..ad4109ff4 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -79,14 +79,3 @@ async def sender(message: Message) -> None: await run_in_threadpool(handler, conn, exc) return wrapped_app - - -@contextmanager -def convert_excgroups() -> typing.Generator[None, None, None]: - try: - yield - except BaseException as exc: - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index c94823d89..d370f9463 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,9 +1,9 @@ import typing +from contextlib import contextmanager import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream -from starlette._exception_handler import convert_excgroups from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request from starlette.responses import ContentStream, Response, StreamingResponse @@ -16,6 +16,17 @@ T = typing.TypeVar("T") +@contextmanager +def _convert_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] + + raise exc + + class _CachedRequest(Request): """ If the user calls Request.body() from their dispatch function @@ -186,7 +197,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - with convert_excgroups(): + with _convert_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index a0abdb6af..c05df3f7d 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -2,7 +2,6 @@ import pytest -from starlette._exception_handler import convert_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ @@ -67,9 +66,12 @@ def test_wsgi_exception(test_client_factory): # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(RuntimeError), convert_excgroups(): + with pytest.raises(ExceptionGroup) as exc: client.get("/") + assert len(exc.value.exceptions) == 1 + assert isinstance(exc.value.exceptions[0], RuntimeError) + def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here. From c1454cf6003eed3be0c0de4b42a5c924c4d656ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 15:08:05 +0200 Subject: [PATCH 21/22] Removed unnecessary imports --- starlette/_exception_handler.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/starlette/_exception_handler.py b/starlette/_exception_handler.py index ad4109ff4..8a9beb3b2 100644 --- a/starlette/_exception_handler.py +++ b/starlette/_exception_handler.py @@ -1,6 +1,4 @@ -import sys import typing -from contextlib import contextmanager from starlette._utils import is_async_callable from starlette.concurrency import run_in_threadpool @@ -10,9 +8,6 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from starlette.websockets import WebSocket -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup - Handler = typing.Callable[..., typing.Any] ExceptionHandlers = typing.Dict[typing.Any, Handler] StatusHandlers = typing.Dict[int, Handler] From 07ceb21a027c6fd3e6dc6ff26b72f9338aaf5015 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Sat, 22 Jul 2023 15:12:41 +0200 Subject: [PATCH 22/22] Fixed import errors --- starlette/middleware/base.py | 4 ++++ tests/middleware/test_wsgi.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index d370f9463..ee99ee6cb 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,3 +1,4 @@ +import sys import typing from contextlib import contextmanager @@ -9,6 +10,9 @@ from starlette.responses import ContentStream, Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import BaseExceptionGroup + RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index c05df3f7d..ad3975403 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -4,6 +4,9 @@ from starlette.middleware.wsgi import WSGIMiddleware, build_environ +if sys.version_info < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + def hello_world(environ, start_response): status = "200 OK"