Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve changes to contexvars made in threadpools #1258

Closed
wants to merge 30 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
73ed57c
Restore context from run_in_threadpool
adriangb Aug 4, 2021
8f1fa9a
linting, type annotations
adriangb Aug 4, 2021
5acec5b
Merge branch 'master' into restore-threadpool
Kludex Sep 16, 2021
9952e47
Merge branch 'master' into restore-threadpool
adriangb Sep 20, 2021
6711330
Merge branch 'master' into restore-threadpool
adriangb Dec 7, 2021
9f880b8
make method private, add tests, document test purpose
adriangb Dec 7, 2021
9976940
Merge branch 'master' into restore-threadpool
adriangb Dec 16, 2021
bc15600
Merge branch 'master' into restore-threadpool
adriangb Dec 24, 2021
abd858a
Merge branch 'master' into restore-threadpool
adriangb Jan 6, 2022
277c229
Merge branch 'master' into restore-threadpool
adriangb Feb 11, 2022
6b92487
fix linting
adriangb Feb 11, 2022
4887245
Merge branch 'master' into restore-threadpool
adriangb Feb 11, 2022
2235db6
Merge branch 'master' into restore-threadpool
adriangb Feb 17, 2022
f41e4da
add tests for behavior of contextvars.ContextVar.reset
adriangb Feb 18, 2022
013938a
Merge branch 'master' into restore-threadpool
adriangb Mar 10, 2022
21579f8
Merge branch 'master' into restore-threadpool
adriangb Apr 24, 2022
2670e32
Merge branch 'master' into restore-threadpool
adriangb May 6, 2022
fabb1b7
Merge branch 'master' into restore-threadpool
Kludex May 6, 2022
bff9a5f
Update starlette/concurrency.py
adriangb May 6, 2022
073ffbe
Update starlette/concurrency.py
adriangb May 6, 2022
9b44f1b
Update concurrency.py
adriangb May 6, 2022
09e185d
Update concurrency.py
adriangb May 6, 2022
bf68fb2
Update concurrency.py
adriangb May 6, 2022
319e20f
Merge branch 'master' into restore-threadpool
adriangb May 22, 2022
2fe3520
Merge branch 'master' into restore-threadpool
adriangb May 26, 2022
473084d
Merge branch 'master' into restore-threadpool
adriangb Jun 2, 2022
728a206
Merge branch 'master' into restore-threadpool
adriangb Jun 4, 2022
54f6643
Merge branch 'master' into restore-threadpool
adriangb Jun 14, 2022
c4f0158
Merge branch 'encode:master' into restore-threadpool
adriangb Jun 24, 2022
2d9b857
Merge branch 'master' into restore-threadpool
Kludex Aug 12, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion starlette/concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,25 @@
else: # pragma: no cover
from typing_extensions import ParamSpec

import contextvars
from contextvars import Context

T = typing.TypeVar("T")
P = ParamSpec("P")


def _restore_context(context: Context) -> None:
"""Copy the state of `context` to the current context."""
for cvar in context:
newval = context.get(cvar)
try:
if cvar.get() != newval:
cvar.set(newval)
except LookupError:
# the context variable was first set inside of `context`
cvar.set(newval)


async def run_until_first_complete(*args: typing.Tuple[typing.Callable, dict]) -> None:
warnings.warn(
"run_until_first_complete is deprecated "
Expand All @@ -38,7 +52,12 @@ async def run_in_threadpool(
if kwargs: # pragma: no cover
# run_sync doesn't accept 'kwargs', so bind them in here
func = functools.partial(func, **kwargs)
return await anyio.to_thread.run_sync(func, *args)
context = contextvars.copy_context()
func = functools.partial(context.run, func) # type: ignore[assignment]
res = await anyio.to_thread.run_sync(func, *args)
if context is not None:
_restore_context(context)
return res


class _StopIteration(Exception):
Expand Down
76 changes: 73 additions & 3 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from contextvars import ContextVar
import contextvars
from typing import List

import anyio
import pytest

from starlette.applications import Starlette
from starlette.concurrency import run_until_first_complete
from starlette.concurrency import run_in_threadpool, run_until_first_complete
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Route
Expand All @@ -29,7 +30,7 @@ async def task2():


def test_accessing_context_from_threaded_sync_endpoint(test_client_factory) -> None:
ctxvar: ContextVar[bytes] = ContextVar("ctxvar")
ctxvar: contextvars.ContextVar[bytes] = contextvars.ContextVar("ctxvar")
ctxvar.set(b"data")

def endpoint(request: Request) -> Response:
Expand All @@ -40,3 +41,72 @@ def endpoint(request: Request) -> Response:

resp = client.get("/")
assert resp.content == b"data"


@pytest.mark.anyio
async def test_restore_context_from_thread_previously_set():
"""Value outside of threadpool is overwitten with value set in threadpool"""
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
ctxvar.set("spam")

def sync_task():
ctxvar.set("ham")

await run_in_threadpool(sync_task)
assert ctxvar.get() == "ham"


@pytest.mark.anyio
async def test_restore_context_from_thread_previously_unset():
"""Value outside of threadpool is set to value in threadpool"""
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")

def sync_task():
ctxvar.set("ham")

await run_in_threadpool(sync_task)
assert ctxvar.get() == "ham"


@pytest.mark.anyio
async def test_restore_context_from_thread_new_cvar():
"""Value outside of threadpool is set for a cvar created in the threadpool"""
ctxvars: List[contextvars.ContextVar[str]] = []

def sync_task():
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
ctxvar.set("ham")
ctxvars.append(ctxvar)

await run_in_threadpool(sync_task)
assert len(ctxvars) == 1
assert next(iter(ctxvars)).get() == "ham"
Comment on lines +46 to +83
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This are the failing tests, which get to the root of the problem in the PR post without jumping through hoops to exercise the issue



@pytest.mark.anyio
async def test_restore_context_from_thread_reset_token_in_child_context():
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
ctxvar.set("spam")

def sync_task():
token = ctxvar.set("ham")
ctxvar.reset(token)

await run_in_threadpool(sync_task)
assert ctxvar.get() == "spam"


@pytest.mark.anyio
async def test_restore_context_from_thread_reset_token_in_parent_context():
ctxvar: contextvars.ContextVar[str] = contextvars.ContextVar("ctxvar")
tokens: List[contextvars.Token[str]] = []

def sync_task():
# this token gets created in the child context
# and hence can't be restored in the parent context
tokens.append(ctxvar.set("ham"))

await run_in_threadpool(sync_task)
assert ctxvar.get() == "ham"
with pytest.raises(ValueError, match="was created in a different Context"):
ctxvar.reset(tokens.pop())