diff --git a/starlette/concurrency.py b/starlette/concurrency.py index 5c76cb3df..4bb9c302b 100644 --- a/starlette/concurrency.py +++ b/starlette/concurrency.py @@ -10,11 +10,25 @@ else: # pragma: no cover from typing_extensions import ParamSpec +import contextvars +from contextvars import Context T = typing.TypeVar("T") P = ParamSpec("P") +def _restore_context(context: Context) -> None: + """Copy the state of `context` to the current context.""" + for cvar in context: + newval = context.get(cvar) + try: + if cvar.get() != newval: + cvar.set(newval) + except LookupError: + # the context variable was first set inside of `context` + cvar.set(newval) + + async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None: warnings.warn( "run_until_first_complete is deprecated " @@ -38,7 +52,12 @@ async def run_in_threadpool( if kwargs: # pragma: no cover # run_sync doesn't accept 'kwargs', so bind them in here func = functools.partial(func, **kwargs) - return await anyio.to_thread.run_sync(func, *args) + context = contextvars.copy_context() + func = functools.partial(context.run, func) # type: ignore[assignment] + res = await anyio.to_thread.run_sync(func, *args) + if context is not None: + _restore_context(context) + return res class _StopIteration(Exception): diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 22b9da0e8..f13f9db24 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -1,10 +1,11 @@ -from contextvars import ContextVar +import contextvars +from typing import List import anyio import pytest from starlette.applications import Starlette -from starlette.concurrency import run_until_first_complete +from starlette.concurrency import run_in_threadpool, run_until_first_complete from starlette.requests import Request from starlette.responses import Response from starlette.routing import Route @@ -29,7 +30,7 @@ async def task2(): def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None: - ctxvar: ContextVar[bytes] = ContextVar("ctxvar") + ctxvar: contextvars.ContextVar[bytes] = contextvars.ContextVar("ctxvar") ctxvar.set(b"data") def endpoint(request: Request) -> Response: @@ -40,3 +41,72 @@ def endpoint(request: Request) -> Response: resp = client.get("/") assert resp.content == b"data" + + +@pytest.mark.anyio +async def test_restore_context_from_thread_previously_set(): + """Value outside of threadpool is overwitten with value set in threadpool""" + ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + ctxvar.set("spam") + + def sync_task(): + ctxvar.set("ham") + + await run_in_threadpool(sync_task) + assert ctxvar.get() == "ham" + + +@pytest.mark.anyio +async def test_restore_context_from_thread_previously_unset(): + """Value outside of threadpool is set to value in threadpool""" + ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + + def sync_task(): + ctxvar.set("ham") + + await run_in_threadpool(sync_task) + assert ctxvar.get() == "ham" + + +@pytest.mark.anyio +async def test_restore_context_from_thread_new_cvar(): + """Value outside of threadpool is set for a cvar created in the threadpool""" + ctxvars: List[contextvars.ContextVar[str]] = [] + + def sync_task(): + ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + ctxvar.set("ham") + ctxvars.append(ctxvar) + + await run_in_threadpool(sync_task) + assert len(ctxvars) == 1 + assert next(iter(ctxvars)).get() == "ham" + + +@pytest.mark.anyio +async def test_restore_context_from_thread_reset_token_in_child_context(): + ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + ctxvar.set("spam") + + def sync_task(): + token = ctxvar.set("ham") + ctxvar.reset(token) + + await run_in_threadpool(sync_task) + assert ctxvar.get() == "spam" + + +@pytest.mark.anyio +async def test_restore_context_from_thread_reset_token_in_parent_context(): + ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar") + tokens: List[contextvars.Token[str]] = [] + + def sync_task(): + # this token gets created in the child context + # and hence can't be restored in the parent context + tokens.append(ctxvar.set("ham")) + + await run_in_threadpool(sync_task) + assert ctxvar.get() == "ham" + with pytest.raises(ValueError, match="was created in a different Context"): + ctxvar.reset(tokens.pop())