Skip to content

Commit

Permalink
ensure TestClient requests run in the same EventLoop as lifespan (#1213
Browse files Browse the repository at this point in the history
)

* ensure TestClient requests run in the same EventLoop as lifespan

* for lifespan task verification, use native task identity rather than anyio.abc.TaskInfo equality

agronholm/anyio#324

* remove redundant pragma: no cover

* it's now a loop_id not a threading.ident

* replace Protocol with plain Callable TypeAlias

* use lifespan_context to actually open a task group

trio should complain if used incorrectly here.

* assign self.portal once, schedule reset immediately after assignment

* inline apps into their tests

* make task/loop trackers nonlocals
  • Loading branch information
graingert authored Jul 3, 2021
1 parent d222b87 commit 254d0d9
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 44 deletions.
82 changes: 51 additions & 31 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from concurrent.futures import Future
from urllib.parse import unquote, urljoin, urlsplit

import anyio
import anyio.abc
import requests
from anyio.streams.stapled import StapledObjectStream

Expand All @@ -24,6 +24,12 @@
else: # pragma: no cover
from typing_extensions import TypedDict


_PortalFactoryType = typing.Callable[
[], typing.ContextManager[anyio.abc.BlockingPortal]
]


# Annotations for `Session.request()`
Cookies = typing.Union[
typing.MutableMapping[str, str], requests.cookies.RequestsCookieJar
Expand Down Expand Up @@ -106,14 +112,14 @@ class _ASGIAdapter(requests.adapters.HTTPAdapter):
def __init__(
self,
app: ASGI3App,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
raise_server_exceptions: bool = True,
root_path: str = "",
) -> None:
self.app = app
self.raise_server_exceptions = raise_server_exceptions
self.root_path = root_path
self.async_backend = async_backend
self.portal_factory = portal_factory

def send(
self, request: requests.PreparedRequest, *args: typing.Any, **kwargs: typing.Any
Expand Down Expand Up @@ -162,7 +168,7 @@ def send(
"server": [host, port],
"subprotocols": subprotocols,
}
session = WebSocketTestSession(self.app, scope, self.async_backend)
session = WebSocketTestSession(self.app, scope, self.portal_factory)
raise _Upgrade(session)

scope = {
Expand Down Expand Up @@ -252,7 +258,7 @@ async def send(message: Message) -> None:
context = message["context"]

try:
with anyio.start_blocking_portal(**self.async_backend) as portal:
with self.portal_factory() as portal:
response_complete = portal.call(anyio.Event)
portal.call(self.app, scope, receive, send)
except BaseException as exc:
Expand Down Expand Up @@ -285,20 +291,18 @@ def __init__(
self,
app: ASGI3App,
scope: Scope,
async_backend: _AsyncBackend,
portal_factory: _PortalFactoryType,
) -> None:
self.app = app
self.scope = scope
self.accepted_subprotocol = None
self.async_backend = async_backend
self.portal_factory = portal_factory
self._receive_queue: "queue.Queue[typing.Any]" = queue.Queue()
self._send_queue: "queue.Queue[typing.Any]" = queue.Queue()

def __enter__(self) -> "WebSocketTestSession":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.portal = self.exit_stack.enter_context(self.portal_factory())

try:
_: "Future[None]" = self.portal.start_task_soon(self._run)
Expand Down Expand Up @@ -396,6 +400,7 @@ def receive_json(self, mode: str = "text") -> typing.Any:
class TestClient(requests.Session):
__test__ = False # For pytest to not discover this up.
task: "Future[None]"
portal: typing.Optional[anyio.abc.BlockingPortal] = None

def __init__(
self,
Expand All @@ -418,7 +423,7 @@ def __init__(
asgi_app = _WrapASGI2(app) #  type: ignore
adapter = _ASGIAdapter(
asgi_app,
self.async_backend,
portal_factory=self._portal_factory,
raise_server_exceptions=raise_server_exceptions,
root_path=root_path,
)
Expand All @@ -430,6 +435,16 @@ def __init__(
self.app = asgi_app
self.base_url = base_url

@contextlib.contextmanager
def _portal_factory(
self,
) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
if self.portal is not None:
yield self.portal
else:
with anyio.start_blocking_portal(**self.async_backend) as portal:
yield portal

def request( # type: ignore
self,
method: str,
Expand Down Expand Up @@ -490,29 +505,34 @@ def websocket_connect(
return session

def __enter__(self) -> "TestClient":
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)
self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
try:
self.task = self.portal.start_task_soon(self.lifespan)
self.portal.call(self.wait_startup)
except Exception:
self.exit_stack.close()
raise
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(
anyio.start_blocking_portal(**self.async_backend)
)

@stack.callback
def reset_portal() -> None:
self.portal = None

self.stream_send = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.stream_receive = StapledObjectStream(
*anyio.create_memory_object_stream(math.inf)
)
self.task = portal.start_task_soon(self.lifespan)
portal.call(self.wait_startup)

@stack.callback
def wait_shutdown() -> None:
portal.call(self.wait_shutdown)

self.exit_stack = stack.pop_all()

return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.portal.call(self.wait_shutdown)
finally:
self.exit_stack.close()
self.exit_stack.close()

async def lifespan(self) -> None:
scope = {"type": "lifespan"}
Expand Down
119 changes: 106 additions & 13 deletions tests/test_testclient.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
import asyncio
import itertools
import sys

import anyio
import pytest
import sniffio
import trio.lowlevel

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.responses import JSONResponse
from starlette.websockets import WebSocket, WebSocketDisconnect

if sys.version_info >= (3, 7):
from asyncio import current_task as asyncio_current_task # pragma: no cover
else:
asyncio_current_task = asyncio.Task.current_task # pragma: no cover

mock_service = Starlette()


Expand All @@ -14,16 +25,19 @@ def mock_service_endpoint(request):
return JSONResponse({"mock": "example"})


def create_app(test_client_factory):
app = Starlette()

@app.route("/")
def homepage(request):
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())
def current_task():
# anyio's TaskInfo comparisons are invalid after their associated native
# task object is GC'd /~https://github.com/agronholm/anyio/issues/324
asynclib_name = sniffio.current_async_library()
if asynclib_name == "trio":
return trio.lowlevel.current_task()

return app
if asynclib_name == "asyncio":
task = asyncio_current_task()
if task is None:
raise RuntimeError("must be called from a running task") # pragma: no cover
return task
raise RuntimeError(f"unsupported asynclib={asynclib_name}") # pragma: no cover


startup_error_app = Starlette()
Expand All @@ -41,14 +55,93 @@ def test_use_testclient_in_endpoint(test_client_factory):
This is useful if we need to mock out other services,
during tests or in development.
"""
client = test_client_factory(create_app(test_client_factory))

app = Starlette()

@app.route("/")
def homepage(request):
client = test_client_factory(mock_service)
response = client.get("/")
return JSONResponse(response.json())

client = test_client_factory(app)
response = client.get("/")
assert response.json() == {"mock": "example"}


def test_use_testclient_as_contextmanager(test_client_factory):
with test_client_factory(create_app(test_client_factory)):
pass
def test_use_testclient_as_contextmanager(test_client_factory, anyio_backend_name):
"""
This test asserts a number of properties that are important for an
app level task_group
"""
counter = itertools.count()
identity_runvar = anyio.lowlevel.RunVar[int]("identity_runvar")

def get_identity():
try:
return identity_runvar.get()
except LookupError:
token = next(counter)
identity_runvar.set(token)
return token

startup_task = object()
startup_loop = None
shutdown_task = object()
shutdown_loop = None

async def lifespan_context(app):
nonlocal startup_task, startup_loop, shutdown_task, shutdown_loop

startup_task = current_task()
startup_loop = get_identity()
async with anyio.create_task_group() as app.task_group:
yield
shutdown_task = current_task()
shutdown_loop = get_identity()

app = Starlette(lifespan=lifespan_context)

@app.route("/loop_id")
async def loop_id(request):
return JSONResponse(get_identity())

client = test_client_factory(app)

with client:
# within a TestClient context every async request runs in the same thread
assert client.get("/loop_id").json() == 0
assert client.get("/loop_id").json() == 0

# that thread is also the same as the lifespan thread
assert startup_loop == 0
assert shutdown_loop == 0

# lifespan events run in the same task, this is important because a task
# group must be entered and exited in the same task.
assert startup_task is shutdown_task

# outside the TestClient context, new requests continue to spawn in new
# eventloops in new threads
assert client.get("/loop_id").json() == 1
assert client.get("/loop_id").json() == 2

first_task = startup_task

with client:
# the TestClient context can be re-used, starting a new lifespan task
# in a new thread
assert client.get("/loop_id").json() == 3
assert client.get("/loop_id").json() == 3

assert startup_loop == 3
assert shutdown_loop == 3

# lifespan events still run in the same task, with the context but...
assert startup_task is shutdown_task

# ... the second TestClient context creates a new lifespan task.
assert first_task is not startup_task


def test_error_on_startup(test_client_factory):
Expand Down

0 comments on commit 254d0d9

Please sign in to comment.