From 480fb3edc501b85ac1d4930c163819eacf40ab24 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 23 Jun 2021 21:14:39 +0100 Subject: [PATCH] remove monkeypatching TestClient interface --- starlette/testclient.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/starlette/testclient.py b/starlette/testclient.py index 7201809e26..3a58c18e0e 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -91,11 +91,16 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await instance(receive, send) +class _AsyncBackend(typing.TypedDict): + backend: str + backend_options: typing.Dict[str, typing.Any] + + class _ASGIAdapter(requests.adapters.HTTPAdapter): def __init__( self, app: ASGI3App, - async_backend: typing.Dict[str, typing.Any], + async_backend: _AsyncBackend, raise_server_exceptions: bool = True, root_path: str = "", ) -> None: @@ -271,7 +276,10 @@ async def send(message: Message) -> None: class WebSocketTestSession: def __init__( - self, app: ASGI3App, scope: Scope, async_backend: typing.Dict[str, typing.Any] + self, + app: ASGI3App, + scope: Scope, + async_backend: _AsyncBackend, ) -> None: self.app = app self.scope = scope @@ -381,11 +389,6 @@ def receive_json(self, mode: str = "text") -> typing.Any: class TestClient(requests.Session): __test__ = False # For pytest to not discover this up. - #: These are the default options for the constructor arguments - async_backend: typing.Dict[str, typing.Any] = { - "backend": "asyncio", - "backend_options": {}, - } task: "Future[None]" def __init__( @@ -394,14 +397,13 @@ def __init__( base_url: str = "http://testserver", raise_server_exceptions: bool = True, root_path: str = "", - backend: typing.Optional[str] = None, + backend: str = "asyncio", backend_options: typing.Optional[typing.Dict[str, typing.Any]] = None, ) -> None: super().__init__() - self.async_backend = { - "backend": backend or self.async_backend["backend"], - "backend_options": backend_options or self.async_backend["backend_options"], - } + self.async_backend = _AsyncBackend( + backend=backend, backend_options=backend_options or {} + ) if _is_asgi3(app): app = typing.cast(ASGI3App, app) asgi_app = app