Skip to content

Commit

Permalink
workaround
Browse files Browse the repository at this point in the history
  • Loading branch information
shughes-uk committed Nov 25, 2021
1 parent b941db3 commit e555890
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 51 deletions.
101 changes: 51 additions & 50 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Awaitable, Callable, Dict, Generic, Optional, TypeVar, overload
from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar, Union, overload

from typing_extensions import ParamSpec

Expand All @@ -20,6 +20,9 @@
else:
contextvars = None

P = ParamSpec("P")
R = TypeVar("R")


def _restore_context(context):
# Check for changes in contextvars, and set them to the current
Expand Down Expand Up @@ -100,11 +103,7 @@ async def __aexit__(self, exc, value, tb):
pass


a_cls_params = ParamSpec("a_cls_params")
a_cls_return = TypeVar("a_cls_return")


class AsyncToSync(Generic[a_cls_params, a_cls_return]):
class AsyncToSync:
"""
Utility class which turns an awaitable that only works on the thread with
the event loop into a synchronous callable that works in a subthread.
Expand All @@ -125,9 +124,7 @@ class AsyncToSync(Generic[a_cls_params, a_cls_return]):
executors = Local()

def __init__(
self,
awaitable: Callable[a_cls_params, Awaitable[a_cls_return]],
force_new_loop=False,
self, awaitable: Callable[..., Awaitable[Any]], force_new_loop: bool = False
):
if not callable(awaitable) or not _iscoroutinefunction_or_partial(awaitable):
# Python does not have very reliable detection of async functions
Expand All @@ -137,7 +134,7 @@ def __init__(
)
self.awaitable = awaitable
try:
self.__self__ = self.awaitable.__self__
self.__self__ = self.awaitable.__self__ # type: ignore
except AttributeError:
pass
if force_new_loop:
Expand All @@ -161,9 +158,7 @@ def __init__(
else:
self.main_event_loop = None

def __call__(
self, *args: a_cls_params.args, **kwargs: a_cls_params.kwargs
) -> a_cls_return:
def __call__(self, *args, **kwargs):
# You can't call AsyncToSync from a thread with a running event loop
try:
event_loop = get_running_loop()
Expand All @@ -184,7 +179,7 @@ def __call__(
context = None

# Make a future for the return information
call_result: Future[a_cls_return] = Future()
call_result = Future()
# Get the source thread
source_thread = threading.current_thread()
# Make a CurrentThreadExecutor we'll use to idle in this thread - we
Expand Down Expand Up @@ -283,13 +278,7 @@ def __get__(self, parent, objtype):
return functools.update_wrapper(func, self.awaitable)

async def main_wrap(
self,
args,
kwargs,
call_result: a_cls_return,
source_thread,
exc_info,
context,
self, args, kwargs, call_result, source_thread, exc_info, context
):
"""
Wraps the awaitable with something that puts the result into the
Expand Down Expand Up @@ -321,11 +310,7 @@ async def main_wrap(
context[0] = contextvars.copy_context()


s_cls_params = ParamSpec("s_cls_params")
s_cls_return = TypeVar("s_cls_return")


class SyncToAsync(Generic[s_cls_params, s_cls_return]):
class SyncToAsync:
"""
Utility class which turns a synchronous callable into an awaitable that
runs in a threadpool. It also sets a threadlocal inside the thread so
Expand Down Expand Up @@ -391,7 +376,7 @@ class SyncToAsync(Generic[s_cls_params, s_cls_return]):

def __init__(
self,
func: Callable[s_cls_params, s_cls_return],
func: Callable[..., Any],
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> None:
Expand All @@ -409,9 +394,7 @@ def __init__(
except AttributeError:
pass

async def __call__(
self, *args: s_cls_params.args, **kwargs: s_cls_params.kwargs
) -> s_cls_return:
async def __call__(self, *args, **kwargs):
loop = get_running_loop()

# Work out what thread to run the code in
Expand Down Expand Up @@ -485,15 +468,7 @@ def __get__(self, parent, objtype):
"""
return functools.partial(self.__call__, parent)

def thread_handler(
self,
loop,
source_task,
exc_info,
func: Callable[s_cls_params, s_cls_return],
*args: s_cls_params.args,
**kwargs: s_cls_params.kwargs
):
def thread_handler(self, loop, source_task, exc_info, func, *args, **kwargs):
"""
Wraps the sync application with exception handling.
"""
Expand Down Expand Up @@ -539,37 +514,63 @@ def get_current_task():
return None


# Lowercase aliases (and decorator friendliness)
async_to_sync = AsyncToSync
# Lowercase aliases (and decorator/typing friendliness)
@overload
def async_to_sync(
func: None = None,
force_new_loop: bool = False,
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, R]]:
...


s_params = ParamSpec("s_params")
s_return = TypeVar("s_return")
@overload
def async_to_sync(
func: Callable[P, Awaitable[R]],
force_new_loop: bool = False,
) -> Callable[P, R]:
...


def async_to_sync(
func: Optional[Callable[P, Awaitable[R]]] = None,
force_new_loop: bool = False,
) -> Union[Callable[P, R], Callable[[Callable[P, Awaitable[R]]], Callable[P, R]]]:
if func is None:
return lambda f: AsyncToSync(
f,
force_new_loop=force_new_loop,
)
return AsyncToSync(
func,
force_new_loop=force_new_loop,
)


@overload
def sync_to_async(
func: None = None,
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> Callable[[Callable[s_params, s_return]], SyncToAsync[s_params, s_return]]:
) -> Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]:
...


@overload
def sync_to_async(
func: Callable[s_params, s_return],
func: Callable[P, R],
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> SyncToAsync[s_params, s_return]:
) -> Callable[P, Awaitable[R]]:
...


def sync_to_async(
func=None,
thread_sensitive=True,
executor=None,
):
func: Optional[Callable[P, R]] = None,
thread_sensitive: bool = True,
executor: Optional["ThreadPoolExecutor"] = None,
) -> Union[
Callable[P, Awaitable[R]], Callable[[Callable[P, R]], Callable[P, Awaitable[R]]]
]:
if func is None:
return lambda f: SyncToAsync(
f,
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ zip_safe = false
tests =
pytest
pytest-asyncio
mypy>=0.800
mypy @ git+ssh://git@github.com/python/mypy.git@master

[tool:pytest]
testpaths = tests
Expand Down

0 comments on commit e555890

Please sign in to comment.