Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgraded to AnyIO 4.0 #2211

Merged
merged 25 commits into from
Jul 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
5c41699
Upgraded to AnyIO 4.0 and dropped Python 3.8 support
agronholm Jul 12, 2023
026ced3
Added conditional imports for ExceptionGroup
agronholm Jul 12, 2023
67e5a41
Fixed mypy errors
agronholm Jul 12, 2023
df70155
Fixed black errors
agronholm Jul 12, 2023
e42480f
Fixed last failure
agronholm Jul 13, 2023
b9f27e4
Added missing test dependency
agronholm Jul 13, 2023
9568e8e
Exclude "raise" in handle_assertion_error() from coverage
agronholm Jul 13, 2023
aa3874c
Exclude conditional import from coverage
agronholm Jul 13, 2023
6c71321
Simplified the changes
agronholm Jul 13, 2023
7225f96
Fixed linting errors
agronholm Jul 13, 2023
1e9096d
Dropped exceptiongroup as a test dependency
agronholm Jul 13, 2023
e1becdb
Restored compatibility with AnyIO 3.x
agronholm Jul 13, 2023
ac67df8
Fixed mypy errors
agronholm Jul 13, 2023
86fdcbf
Fixed linting errors
agronholm Jul 13, 2023
efe6cf5
Merge branch 'master' into anyio4
agronholm Jul 13, 2023
efef252
Merge branch 'master' into anyio4
agronholm Jul 13, 2023
9773d06
Use general ABCs for object stream type annotations for consistency
agronholm Jul 13, 2023
053b96b
Moved convert_excgroups() from tests to the main code base
agronholm Jul 22, 2023
68cc1a8
Updated anyio dependency to point to master
agronholm Jul 22, 2023
ef5f16a
Fixed linting error
agronholm Jul 22, 2023
cb6008b
Ignore coverage for a conditional import
agronholm Jul 22, 2023
fcb9056
Moved the exception group converter to middleware/base.py
agronholm Jul 22, 2023
c1454cf
Removed unnecessary imports
agronholm Jul 22, 2023
07ceb21
Fixed import errors
agronholm Jul 22, 2023
2dc718e
Merge branch 'master' into anyio4
Kludex Jul 23, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ types-contextvars==2.4.7.2
types-PyYAML==6.0.12.10
types-dataclasses==0.6.6
pytest==7.4.0
trio==0.21.0
trio==0.22.1
anyio@git+/~https://github.com/agronholm/anyio.git

# Documentation
mkdocs==1.4.3
Expand Down
28 changes: 24 additions & 4 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
import sys
import typing
from contextlib import contextmanager

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.background import BackgroundTask
from starlette.requests import ClientDisconnect, Request
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import BaseExceptionGroup

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
DispatchFunction = typing.Callable[
[Request, RequestResponseEndpoint], typing.Awaitable[Response]
]
T = typing.TypeVar("T")


@contextmanager
def _convert_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc


class _CachedRequest(Request):
"""
If the user calls Request.body() from their dispatch function
Expand Down Expand Up @@ -107,6 +124,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

async def call_next(request: Request) -> Response:
app_exc: typing.Optional[Exception] = None
send_stream: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
recv_stream: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send_stream, recv_stream = anyio.create_memory_object_stream()

async def receive_or_disconnect() -> Message:
Expand Down Expand Up @@ -182,10 +201,11 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
response.raw_headers = message["headers"]
return response

async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
with _convert_excgroups():
async with anyio.create_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()

async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
Expand Down
4 changes: 4 additions & 0 deletions starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import warnings

import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -72,6 +73,9 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:


class WSGIResponder:
stream_send: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
stream_receive: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]

def __init__(self, app: typing.Callable, scope: Scope) -> None:
self.app = app
self.scope = scope
Expand Down
19 changes: 13 additions & 6 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import anyio
import anyio.from_thread
from anyio.abc import ObjectReceiveStream, ObjectSendStream
from anyio.streams.stapled import StapledObjectStream

from starlette._utils import is_async_callable
Expand Down Expand Up @@ -737,12 +738,18 @@ def __enter__(self) -> "TestClient":
def reset_portal() -> None:
self.portal = None

self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
send1: ObjectSendStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
receive1: ObjectReceiveStream[
typing.Optional[typing.MutableMapping[str, typing.Any]]
]
send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
send1, receive1 = anyio.create_memory_object_stream(math.inf)
send2, receive2 = anyio.create_memory_object_stream(math.inf)
self.stream_send = StapledObjectStream(send1, receive1)
self.stream_receive = StapledObjectStream(send2, receive2)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

Expand Down
8 changes: 7 additions & 1 deletion tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

from starlette.middleware.wsgi import WSGIMiddleware, build_environ

if sys.version_info < (3, 11): # pragma: no cover
from exceptiongroup import ExceptionGroup


def hello_world(environ, start_response):
status = "200 OK"
Expand Down Expand Up @@ -66,9 +69,12 @@ def test_wsgi_exception(test_client_factory):
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(RuntimeError):
with pytest.raises(ExceptionGroup) as exc:
client.get("/")

assert len(exc.value.exceptions) == 1
assert isinstance(exc.value.exceptions[0], RuntimeError)


def test_wsgi_exc_info(test_client_factory):
# Note that we're testing the WSGI app directly here.
Expand Down
4 changes: 4 additions & 0 deletions tests/test_websockets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys
from typing import Any, MutableMapping

import anyio
import pytest
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette import status
from starlette.types import Receive, Scope, Send
Expand Down Expand Up @@ -178,6 +180,8 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:


def test_websocket_concurrency_pattern(test_client_factory):
stream_send: ObjectSendStream[MutableMapping[str, Any]]
stream_receive: ObjectReceiveStream[MutableMapping[str, Any]]
stream_send, stream_receive = anyio.create_memory_object_stream()

async def reader(websocket):
Expand Down