From 95951dbc005ec993989162b38f1cb92847d5c178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Thu, 13 Jun 2024 20:57:44 +0200 Subject: [PATCH] fix(testing): `.websocket_connect` does not respect `base_url` (#3567) * Fix WS client base url --- litestar/testing/client/async_client.py | 397 +----------------- litestar/testing/client/base.py | 37 +- litestar/testing/client/sync_client.py | 397 +----------------- tests/unit/test_connection/test_websocket.py | 15 +- tests/unit/test_contrib/test_opentelemetry.py | 6 +- 5 files changed, 73 insertions(+), 779 deletions(-) diff --git a/litestar/testing/client/async_client.py b/litestar/testing/client/async_client.py index 0e4d779170..4e28bef4ac 100644 --- a/litestar/testing/client/async_client.py +++ b/litestar/testing/client/async_client.py @@ -2,11 +2,9 @@ from contextlib import AsyncExitStack from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar -from urllib.parse import urljoin -from httpx import USE_CLIENT_DEFAULT, AsyncClient, Response +from httpx import USE_CLIENT_DEFAULT, AsyncClient -from litestar import HttpMethod from litestar.testing.client.base import BaseTestClient from litestar.testing.life_span_handler import LifeSpanHandler from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport @@ -19,11 +17,7 @@ CookieTypes, HeaderTypes, QueryParamTypes, - RequestContent, - RequestData, - RequestFiles, TimeoutTypes, - URLTypes, ) from typing_extensions import Self @@ -107,369 +101,6 @@ def wait_shutdown() -> None: async def __aexit__(self, *args: Any) -> None: await self.exit_stack.aclose() - async def request( - self, - method: str, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a request. - - Args: - method: An HTTP method. - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.request( - self, - url=self.base_url.join(url), - method=method.value if isinstance(method, HttpMethod) else method, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def get( # type: ignore [override] - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a GET request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.get( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def options( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends an OPTIONS request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.options( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def head( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a HEAD request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.head( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def post( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a POST request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.post( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def put( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a PUT request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.put( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def patch( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a PATCH request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.patch( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - async def delete( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a DELETE request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return await AsyncClient.delete( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - async def websocket_connect( self, url: str, @@ -498,25 +129,19 @@ async def websocket_connect( Returns: A `WebSocketTestSession ` instance. """ - url = urljoin("ws://testserver", url) - default_headers: dict[str, str] = {} - default_headers.setdefault("connection", "upgrade") - default_headers.setdefault("sec-websocket-key", "testserver==") - default_headers.setdefault("sec-websocket-version", "13") - if subprotocols is not None: - default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) try: - await AsyncClient.request( - self, - "GET", - url, - headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] - params=params, - cookies=cookies, + await self.send( + self._prepare_ws_connect_request( + url=url, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + extensions=extensions, + timeout=timeout, + ), auth=auth, follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), ) except ConnectionUpgradeExceptionError as exc: return exc.session diff --git a/litestar/testing/client/base.py b/litestar/testing/client/base.py index 3c25be117b..ddaed17935 100644 --- a/litestar/testing/client/base.py +++ b/litestar/testing/client/base.py @@ -2,11 +2,13 @@ from contextlib import contextmanager from http.cookiejar import CookieJar -from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generator, Generic, Mapping, Sequence, TypeVar, cast from warnings import warn +import httpx from anyio.from_thread import BlockingPortal, start_blocking_portal from httpx import Cookies, Request, Response +from httpx._client import USE_CLIENT_DEFAULT, BaseClient, UseClientDefault from litestar import Litestar from litestar.connection import ASGIConnection @@ -19,7 +21,12 @@ from litestar.utils.scope.state import ScopeState if TYPE_CHECKING: - from httpx._types import CookieTypes + from httpx._types import ( + CookieTypes, + HeaderTypes, + QueryParamTypes, + TimeoutTypes, + ) from litestar.middleware.session.base import BaseBackendConfig, BaseSessionBackend from litestar.types.asgi_types import HTTPScope, Receive, Scope, Send @@ -178,3 +185,29 @@ async def _get_session_data(self) -> dict[str, Any]: cookies=dict(self.cookies), # type: ignore[arg-type] ), ) + + def _prepare_ws_connect_request( # type: ignore[misc] + self: BaseClient, # pyright: ignore + url: str, + subprotocols: Sequence[str] | None = None, + params: QueryParamTypes | None = None, + headers: HeaderTypes | None = None, + cookies: CookieTypes | None = None, + timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + extensions: Mapping[str, Any] | None = None, + ) -> httpx.Request: + default_headers: dict[str, str] = {} + default_headers.setdefault("connection", "upgrade") + default_headers.setdefault("sec-websocket-key", "testserver==") + default_headers.setdefault("sec-websocket-version", "13") + if subprotocols is not None: + default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) + return self.build_request( + "GET", + self.base_url.copy_with(scheme="ws").join(url), + headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] + params=params, + cookies=cookies, + extensions=None if extensions is None else dict(extensions), + timeout=timeout, + ) diff --git a/litestar/testing/client/sync_client.py b/litestar/testing/client/sync_client.py index d90705646b..9cbfcb3d94 100644 --- a/litestar/testing/client/sync_client.py +++ b/litestar/testing/client/sync_client.py @@ -2,11 +2,9 @@ from contextlib import ExitStack from typing import TYPE_CHECKING, Any, Generic, Mapping, Sequence, TypeVar -from urllib.parse import urljoin -from httpx import USE_CLIENT_DEFAULT, Client, Response +from httpx import USE_CLIENT_DEFAULT, Client -from litestar import HttpMethod from litestar.testing.client.base import BaseTestClient from litestar.testing.life_span_handler import LifeSpanHandler from litestar.testing.transport import ConnectionUpgradeExceptionError, TestClientTransport @@ -19,11 +17,7 @@ CookieTypes, HeaderTypes, QueryParamTypes, - RequestContent, - RequestData, - RequestFiles, TimeoutTypes, - URLTypes, ) from typing_extensions import Self @@ -109,369 +103,6 @@ def wait_shutdown() -> None: def __exit__(self, *args: Any) -> None: self.exit_stack.close() - def request( - self, - method: str, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a request. - - Args: - method: An HTTP method. - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.request( - self, - url=self.base_url.join(url), - method=method.value if isinstance(method, HttpMethod) else method, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def get( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a GET request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.get( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def options( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends an OPTIONS request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.options( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def head( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a HEAD request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.head( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def post( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a POST request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.post( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def put( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a PUT request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.put( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def patch( - self, - url: URLTypes, - *, - content: RequestContent | None = None, - data: RequestData | None = None, - files: RequestFiles | None = None, - json: Any | None = None, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a PATCH request. - - Args: - url: URL or path for the request. - content: Request content. - data: Form encoded data. - files: Multipart files to send. - json: JSON data to send. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.patch( - self, - url, - content=content, - data=data, - files=files, - json=json, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - - def delete( - self, - url: URLTypes, - *, - params: QueryParamTypes | None = None, - headers: HeaderTypes | None = None, - cookies: CookieTypes | None = None, - auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, - follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, - timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, - extensions: Mapping[str, Any] | None = None, - ) -> Response: - """Sends a DELETE request. - - Args: - url: URL or path for the request. - params: Query parameters. - headers: Request headers. - cookies: Request cookies. - auth: Auth headers. - follow_redirects: Whether to follow redirects. - timeout: Request timeout. - extensions: Dictionary of ASGI extensions. - - Returns: - An HTTPX Response. - """ - return Client.delete( - self, - url, - params=params, - headers=headers, - cookies=cookies, - auth=auth, - follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), - ) - def websocket_connect( self, url: str, @@ -500,25 +131,19 @@ def websocket_connect( Returns: A `WebSocketTestSession ` instance. """ - url = urljoin("ws://testserver", url) - default_headers: dict[str, str] = {} - default_headers.setdefault("connection", "upgrade") - default_headers.setdefault("sec-websocket-key", "testserver==") - default_headers.setdefault("sec-websocket-version", "13") - if subprotocols is not None: - default_headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols)) try: - Client.request( - self, - "GET", - url, - headers={**dict(headers or {}), **default_headers}, # type: ignore[misc] - params=params, - cookies=cookies, + self.send( + self._prepare_ws_connect_request( + url=url, + subprotocols=subprotocols, + params=params, + headers=headers, + cookies=cookies, + extensions=extensions, + timeout=timeout, + ), auth=auth, follow_redirects=follow_redirects, - timeout=timeout, - extensions=None if extensions is None else dict(extensions), ) except ConnectionUpgradeExceptionError as exc: return exc.session diff --git a/tests/unit/test_connection/test_websocket.py b/tests/unit/test_connection/test_websocket.py index 840d64b0fb..5e7a2e2d38 100644 --- a/tests/unit/test_connection/test_websocket.py +++ b/tests/unit/test_connection/test_websocket.py @@ -92,7 +92,18 @@ async def handler(socket: WebSocket) -> None: await socket.close() with create_test_client(handler).websocket_connect("/123?a=abc") as ws: - assert ws.receive_json() == {"url": "ws://testserver/123?a=abc"} + assert ws.receive_json() == {"url": "ws://testserver.local/123?a=abc"} + + +def test_websocket_url_respects_custom_base_url() -> None: + @websocket("/123") + async def handler(socket: WebSocket) -> None: + await socket.accept() + await socket.send_json({"url": str(socket.url)}) + await socket.close() + + with create_test_client(handler, base_url="http://example.org").websocket_connect("/123?a=abc") as ws: + assert ws.receive_json() == {"url": "ws://example.org/123?a=abc"} def test_websocket_binary_json() -> None: @@ -133,7 +144,7 @@ async def handler(socket: WebSocket) -> None: "accept": "*/*", "accept-encoding": "gzip, deflate, br", "connection": "upgrade", - "host": "testserver", + "host": "testserver.local", "user-agent": "testclient", "sec-websocket-key": "testserver==", "sec-websocket-version": "13", diff --git a/tests/unit/test_contrib/test_opentelemetry.py b/tests/unit/test_contrib/test_opentelemetry.py index ef16140b5e..4d4d611887 100644 --- a/tests/unit/test_contrib/test_opentelemetry.py +++ b/tests/unit/test_contrib/test_opentelemetry.py @@ -113,11 +113,11 @@ async def handler(socket: "WebSocket") -> None: assert dict(fourth_span.attributes) == {"type": "websocket.close"} # type: ignore[arg-type] assert dict(fifth_span.attributes) == { # type: ignore[arg-type] "http.scheme": "ws", - "http.host": "testserver", + "http.host": "testserver.local", "net.host.port": 80, "http.target": "/", - "http.url": "ws://testserver/", - "http.server_name": "testserver", + "http.url": "ws://testserver.local/", + "http.server_name": "testserver.local", "http.user_agent": "testclient", "net.peer.ip": "testclient", "net.peer.port": 50000,