From 03832bdf1b7e2e7459d9e38a0c11b822118fe168 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 30 Jun 2021 19:21:00 +0100 Subject: [PATCH 1/7] use an async context manager factory for lifespan --- setup.py | 1 + starlette/applications.py | 2 +- starlette/routing.py | 105 ++++++++++++++++++++++++++++++-------- 3 files changed, 86 insertions(+), 22 deletions(-) diff --git a/setup.py b/setup.py index ac6479746..dbed23c07 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ def get_long_description(): install_requires=[ "anyio>=3.0.0,<4", "typing_extensions; python_version < '3.8'", + "contextlib2; python_version < '3.10'", ], extras_require={ "full": [ diff --git a/starlette/applications.py b/starlette/applications.py index 34c3e38bd..ea52ee70e 100644 --- a/starlette/applications.py +++ b/starlette/applications.py @@ -46,7 +46,7 @@ def __init__( ] = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[["Starlette"], typing.AsyncGenerator] = None, + lifespan: typing.Callable[["Starlette"], typing.AsyncContextManager] = None, ) -> None: # The lifespan context function is a newer style that replaces # on_startup / on_shutdown handlers. Use one or the other, not both. diff --git a/starlette/routing.py b/starlette/routing.py index cef1ef484..096d3a436 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -1,9 +1,12 @@ import asyncio +import contextlib import functools import inspect import re +import sys import traceback import typing +import warnings from enum import Enum from starlette.concurrency import run_in_threadpool @@ -15,6 +18,16 @@ from starlette.types import ASGIApp, Receive, Scope, Send from starlette.websockets import WebSocket, WebSocketClose +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager +else: + from contextlib2 import asynccontextmanager + +if sys.version_info >= (3, 10): + from contextlib import aclosing +else: + from contextlib2 import aclosing + class NoMatchFound(Exception): """ @@ -470,6 +483,54 @@ def __eq__(self, other: typing.Any) -> bool: ) +def _wrap_agen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.AsyncGenerator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + @functools.wraps(lifespan_context) + @asynccontextmanager + async def agen_wrapper( + app: typing.Any, + ) -> typing.AsyncGenerator[None, None]: + async with aclosing(lifespan_context(app)) as agen: # type: ignore + async for _ in agen: + yield + + return agen_wrapper + + +def _wrap_gen_lifespan_context( + lifespan_context: typing.Callable[[typing.Any], typing.Generator] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + @functools.wraps(lifespan_context) + @asynccontextmanager + async def gen_wrapper( + app: typing.Any, + ) -> typing.AsyncGenerator[None, None]: + with contextlib.closing(lifespan_context(app)) as gen: + for _ in gen: + yield + + return gen_wrapper + + +def _wrap_lifespan_context( + lifespan_context: typing.Union[ + typing.Callable[[typing.Any], typing.AsyncGenerator], + typing.Callable[[typing.Any], typing.Generator], + typing.Callable[[typing.Any], typing.AsyncContextManager], + ] +) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: + if inspect.isasyncgenfunction(lifespan_context): + warnings.warn("lifespan must be an AsyncContextManager factory") + return _wrap_agen_lifespan_context(lifespan_context) # type: ignore[arg-type] + + if inspect.isgeneratorfunction(lifespan_context): + warnings.warn("lifespan must be an AsyncContextManager factory") + return _wrap_gen_lifespan_context(lifespan_context) # type: ignore[arg-type] + + return lifespan_context # type: ignore + + class Router: def __init__( self, @@ -478,7 +539,7 @@ def __init__( default: ASGIApp = None, on_startup: typing.Sequence[typing.Callable] = None, on_shutdown: typing.Sequence[typing.Callable] = None, - lifespan: typing.Callable[[typing.Any], typing.AsyncGenerator] = None, + lifespan: typing.Callable[[typing.Any], typing.AsyncContextManager] = None, ) -> None: self.routes = [] if routes is None else list(routes) self.redirect_slashes = redirect_slashes @@ -486,12 +547,21 @@ def __init__( self.on_startup = [] if on_startup is None else list(on_startup) self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) + @asynccontextmanager async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator: await self.startup() - yield - await self.shutdown() - - self.lifespan_context = default_lifespan if lifespan is None else lifespan + try: + yield + finally: + await self.shutdown() + + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = ( + default_lifespan # type: ignore + if lifespan is None + else _wrap_lifespan_context(lifespan) + ) async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": @@ -541,25 +611,18 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: Handle ASGI lifespan messages, which allows us to manage application startup and shutdown events. """ - first = True + started = False app = scope.get("app") - await receive() try: - if inspect.isasyncgenfunction(self.lifespan_context): - async for item in self.lifespan_context(app): - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() - else: - for item in self.lifespan_context(app): # type: ignore - assert first, "Lifespan context yielded multiple times." - first = False - await send({"type": "lifespan.startup.complete"}) - await receive() + async with self.lifespan_context(app): + await send({"type": "lifespan.startup.complete"}) + started = True + await receive() except BaseException: - if first: - exc_text = traceback.format_exc() + exc_text = traceback.format_exc() + if started: + await send({"type": "lifespan.shutdown.failed", "message": exc_text}) + else: await send({"type": "lifespan.startup.failed", "message": exc_text}) raise else: From 13f8fe103b9ee47ad2e802359ba3adbe1d81cb0c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 10:36:29 +0100 Subject: [PATCH 2/7] simplify asynccontextmanager upgrading --- starlette/routing.py | 115 +++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 58 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 096d3a436..449e2920a 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -5,6 +5,7 @@ import re import sys import traceback +import types import typing import warnings from enum import Enum @@ -23,11 +24,6 @@ else: from contextlib2 import asynccontextmanager -if sys.version_info >= (3, 10): - from contextlib import aclosing -else: - from contextlib2 import aclosing - class NoMatchFound(Exception): """ @@ -483,52 +479,49 @@ def __eq__(self, other: typing.Any) -> bool: ) -def _wrap_agen_lifespan_context( - lifespan_context: typing.Callable[[typing.Any], typing.AsyncGenerator] -) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: - @functools.wraps(lifespan_context) - @asynccontextmanager - async def agen_wrapper( - app: typing.Any, - ) -> typing.AsyncGenerator[None, None]: - async with aclosing(lifespan_context(app)) as agen: # type: ignore - async for _ in agen: - yield +_T = typing.TypeVar("_T") - return agen_wrapper + +class _AsyncLiftContextManager(typing.AsyncContextManager[_T]): + def __init__(self, cm: typing.ContextManager[_T]): + self._cm = cm + + async def __aenter__(self) -> _T: + return self._cm.__enter__() + + async def __aexit__( + self, + exc_type: typing.Optional[typing.Type[BaseException]], + exc_value: typing.Optional[BaseException], + traceback: typing.Optional[types.TracebackType], + ) -> typing.Optional[bool]: + return self._cm.__exit__(exc_type, exc_value, traceback) def _wrap_gen_lifespan_context( lifespan_context: typing.Callable[[typing.Any], typing.Generator] ) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: - @functools.wraps(lifespan_context) - @asynccontextmanager - async def gen_wrapper( - app: typing.Any, - ) -> typing.AsyncGenerator[None, None]: - with contextlib.closing(lifespan_context(app)) as gen: - for _ in gen: - yield - - return gen_wrapper - - -def _wrap_lifespan_context( - lifespan_context: typing.Union[ - typing.Callable[[typing.Any], typing.AsyncGenerator], - typing.Callable[[typing.Any], typing.Generator], - typing.Callable[[typing.Any], typing.AsyncContextManager], - ] -) -> typing.Callable[[typing.Any], typing.AsyncContextManager]: - if inspect.isasyncgenfunction(lifespan_context): - warnings.warn("lifespan must be an AsyncContextManager factory") - return _wrap_agen_lifespan_context(lifespan_context) # type: ignore[arg-type] + cmgr = contextlib.contextmanager(lifespan_context) + + @functools.wraps(cmgr) + def wrapper(app: typing.Any) -> _AsyncLiftContextManager: + return _AsyncLiftContextManager(cmgr(app)) + + return wrapper + - if inspect.isgeneratorfunction(lifespan_context): - warnings.warn("lifespan must be an AsyncContextManager factory") - return _wrap_gen_lifespan_context(lifespan_context) # type: ignore[arg-type] +class _DefaultLifespan: + def __init__(self, router: "Router"): + self._router = router - return lifespan_context # type: ignore + async def __aenter__(self) -> None: + await self._router.startup() + + async def __aexit__(self, *exc_info: object) -> None: + await self._router.shutdown() + + def __call__(self: _T, app: object) -> _T: + return self class Router: @@ -547,21 +540,27 @@ def __init__( self.on_startup = [] if on_startup is None else list(on_startup) self.on_shutdown = [] if on_shutdown is None else list(on_shutdown) - @asynccontextmanager - async def default_lifespan(app: typing.Any) -> typing.AsyncGenerator: - await self.startup() - try: - yield - finally: - await self.shutdown() - - self.lifespan_context: typing.Callable[ - [typing.Any], typing.AsyncContextManager - ] = ( - default_lifespan # type: ignore - if lifespan is None - else _wrap_lifespan_context(lifespan) - ) + if lifespan is None: + self.lifespan_context: typing.Callable[ + [typing.Any], typing.AsyncContextManager + ] = _DefaultLifespan(self) + + elif inspect.isasyncgenfunction(lifespan): + warnings.warn( + "lifespan must be an AsyncContextManager factory", DeprecationWarning + ) + self.lifespan_context = asynccontextmanager( + lifespan, # type: ignore[arg-type] + ) + elif inspect.isgeneratorfunction(lifespan): + warnings.warn( + "lifespan must be an AsyncContextManager factory", DeprecationWarning + ) + self.lifespan_context = _wrap_gen_lifespan_context( + lifespan, # type: ignore[arg-type] + ) + else: + self.lifespan_context = lifespan async def not_found(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] == "websocket": From 68f96061e43e15fc757bab6205eab18367028f79 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 11:08:03 +0100 Subject: [PATCH 3/7] make tests pass --- starlette/routing.py | 5 +++-- starlette/testclient.py | 9 +++++++- tests/test_applications.py | 42 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 51 insertions(+), 5 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 449e2920a..31590275a 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -20,9 +20,9 @@ from starlette.websockets import WebSocket, WebSocketClose if sys.version_info >= (3, 7): - from contextlib import asynccontextmanager + from contextlib import asynccontextmanager # pragma: no cover else: - from contextlib2 import asynccontextmanager + from contextlib2 import asynccontextmanager # pragma: no cover class NoMatchFound(Exception): @@ -612,6 +612,7 @@ async def lifespan(self, scope: Scope, receive: Receive, send: Send) -> None: """ started = False app = scope.get("app") + await receive() try: async with self.lifespan_context(app): await send({"type": "lifespan.startup.complete"}) diff --git a/starlette/testclient.py b/starlette/testclient.py index 33bb410d0..b5ae28b49 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -541,4 +541,11 @@ async def wait_shutdown(self) -> None: message = await self.stream_send.receive() if message is None: self.task.result() - assert message["type"] == "lifespan.shutdown.complete" + assert message["type"] in ( + "lifespan.shutdown.complete", + "lifespan.shutdown.failed", + ) + if message["type"] == "lifespan.shutdown.failed": + message = await self.stream_send.receive() + if message is None: + self.task.result() diff --git a/tests/test_applications.py b/tests/test_applications.py index 6cb490696..509edc4df 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -1,4 +1,5 @@ import os +import sys import pytest @@ -10,6 +11,11 @@ from starlette.routing import Host, Mount, Route, Router, WebSocketRoute from starlette.staticfiles import StaticFiles +if sys.version_info >= (3, 7): + from contextlib import asynccontextmanager # pragma: no cover +else: + from contextlib2 import asynccontextmanager # pragma: no cover + app = Starlette() @@ -286,7 +292,38 @@ def run_cleanup(): assert cleanup_complete -def test_app_async_lifespan(test_client_factory): +def test_app_async_cm_lifespan(test_client_factory): + startup_complete = False + cleanup_complete = False + + @asynccontextmanager + async def lifespan(app): + nonlocal startup_complete, cleanup_complete + startup_complete = True + yield + cleanup_complete = True + + app = Starlette(lifespan=lifespan) + + assert not startup_complete + assert not cleanup_complete + with test_client_factory(app): + assert startup_complete + assert not cleanup_complete + assert startup_complete + assert cleanup_complete + + +deprecated_lifespan = pytest.mark.filterwarnings( + "ignore" + ":lifespan must be an AsyncContextManager factory" + ":DeprecationWarning" + ":starlette.routing" +) + + +@deprecated_lifespan +def test_app_async_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False @@ -307,7 +344,8 @@ async def lifespan(app): assert cleanup_complete -def test_app_sync_lifespan(test_client_factory): +@deprecated_lifespan +def test_app_sync_gen_lifespan(test_client_factory): startup_complete = False cleanup_complete = False From e8c3afb95d8c9f6127a060988b591536ae17d4d7 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 11:29:47 +0100 Subject: [PATCH 4/7] get last bit of coverage --- starlette/testclient.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index b5ae28b49..83822d43b 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -523,29 +523,34 @@ async def lifespan(self) -> None: async def wait_startup(self) -> None: await self.stream_receive.send({"type": "lifespan.startup"}) - message = await self.stream_send.receive() - if message is None: - self.task.result() + + async def receive() -> typing.Any: + message = await self.stream_send.receive() + if message is None: + self.task.result() + return message + + message = await receive() assert message["type"] in ( "lifespan.startup.complete", "lifespan.startup.failed", ) if message["type"] == "lifespan.startup.failed": + await receive() + + async def wait_shutdown(self) -> None: + async def receive() -> typing.Any: message = await self.stream_send.receive() if message is None: self.task.result() + return message - async def wait_shutdown(self) -> None: async with self.stream_send: await self.stream_receive.send({"type": "lifespan.shutdown"}) - message = await self.stream_send.receive() - if message is None: - self.task.result() + message = await receive() assert message["type"] in ( "lifespan.shutdown.complete", "lifespan.shutdown.failed", ) if message["type"] == "lifespan.shutdown.failed": - message = await self.stream_send.receive() - if message is None: - self.task.result() + await receive() From 4cc26971a48dd8315cfd587ba226c4102f7da2dc Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 11:33:32 +0100 Subject: [PATCH 5/7] narrow contextlib2 dep --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index dbed23c07..31789fe09 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def get_long_description(): install_requires=[ "anyio>=3.0.0,<4", "typing_extensions; python_version < '3.8'", - "contextlib2; python_version < '3.10'", + "contextlib2 >= 21.6.0; python_version < '3.7'", ], extras_require={ "full": [ From 1aed67d65ff2bb190c20753ece05e8e3fc7447e8 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 17:59:13 +0100 Subject: [PATCH 6/7] use @asynccontextmanager in test_use_testclient_as_contextmanager --- tests/test_testclient.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/test_testclient.py b/tests/test_testclient.py index 57ea1c3db..8c0666789 100644 --- a/tests/test_testclient.py +++ b/tests/test_testclient.py @@ -12,10 +12,12 @@ from starlette.responses import JSONResponse from starlette.websockets import WebSocket, WebSocketDisconnect -if sys.version_info >= (3, 7): - from asyncio import current_task as asyncio_current_task # pragma: no cover -else: - asyncio_current_task = asyncio.Task.current_task # pragma: no cover +if sys.version_info >= (3, 7): # pragma: no cover + from asyncio import current_task as asyncio_current_task + from contextlib import asynccontextmanager +else: # pragma: no cover + asyncio_current_task = asyncio.Task.current_task + from contextlib2 import asynccontextmanager mock_service = Starlette() @@ -90,6 +92,7 @@ def get_identity(): shutdown_task = object() shutdown_loop = None + @asynccontextmanager async def lifespan_context(app): nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop From cedceb488573e17b860891fbd7b920a587f0fd5d Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Sat, 3 Jul 2021 18:08:05 +0100 Subject: [PATCH 7/7] improve lifespan context deprecation warnings --- starlette/routing.py | 8 ++++++-- tests/test_applications.py | 9 +++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index 31590275a..9a1a5e12d 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -547,14 +547,18 @@ def __init__( elif inspect.isasyncgenfunction(lifespan): warnings.warn( - "lifespan must be an AsyncContextManager factory", DeprecationWarning + "async generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, ) self.lifespan_context = asynccontextmanager( lifespan, # type: ignore[arg-type] ) elif inspect.isgeneratorfunction(lifespan): warnings.warn( - "lifespan must be an AsyncContextManager factory", DeprecationWarning + "generator function lifespans are deprecated, " + "use an @contextlib.asynccontextmanager function instead", + DeprecationWarning, ) self.lifespan_context = _wrap_gen_lifespan_context( lifespan, # type: ignore[arg-type] diff --git a/tests/test_applications.py b/tests/test_applications.py index 509edc4df..f5f4c7fbe 100644 --- a/tests/test_applications.py +++ b/tests/test_applications.py @@ -315,10 +315,11 @@ async def lifespan(app): deprecated_lifespan = pytest.mark.filterwarnings( - "ignore" - ":lifespan must be an AsyncContextManager factory" - ":DeprecationWarning" - ":starlette.routing" + r"ignore" + r":(async )?generator function lifespans are deprecated, use an " + r"@contextlib\.asynccontextmanager function instead" + r":DeprecationWarning" + r":starlette.routing" )