Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reason to WebSocket closure #1417

Merged
merged 20 commits into from
Jan 22, 2022
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/websockets.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Use `websocket.receive_json(data, mode="binary")` to receive JSON over binary da

### Closing the connection

* `await websocket.close(code=1000)`
* `await websocket.close(code=1000, reason=None)`

### Sending and receiving messages

Expand Down
2 changes: 1 addition & 1 deletion starlette/testclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ async def _asgi_send(self, message: Message) -> None:

def _raise_on_close(self, message: Message) -> None:
if message["type"] == "websocket.close":
raise WebSocketDisconnect(message.get("code", 1000))
raise WebSocketDisconnect(message.get("code", 1000), message.get("reason"))
aminalaee marked this conversation as resolved.
Show resolved Hide resolved

def send(self, message: Message) -> None:
self._receive_queue.put(message)
Expand Down
14 changes: 9 additions & 5 deletions starlette/websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ class WebSocketState(enum.Enum):


class WebSocketDisconnect(Exception):
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason


class WebSocket(HTTPConnection):
Expand Down Expand Up @@ -144,13 +145,16 @@ async def send_json(self, data: typing.Any, mode: str = "text") -> None:
else:
await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

async def close(self, code: int = 1000) -> None:
await self.send({"type": "websocket.close", "code": code})
async def close(self, code: int = 1000, reason: str = None) -> None:
await self.send({"type": "websocket.close", "code": code, "reason": reason})
aminalaee marked this conversation as resolved.
Show resolved Hide resolved


class WebSocketClose:
def __init__(self, code: int = 1000) -> None:
def __init__(self, code: int = 1000, reason: str = None) -> None:
self.code = code
self.reason = reason

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "websocket.close", "code": self.code})
await send(
{"type": "websocket.close", "code": self.code, "reason": self.reason}
)
38 changes: 37 additions & 1 deletion tests/test_websockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest

from starlette import status
from starlette.websockets import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocket, WebSocketClose, WebSocketDisconnect


def test_websocket_url(test_client_factory):
Expand Down Expand Up @@ -391,3 +391,39 @@ async def mock_send(message):
assert websocket == websocket
assert websocket in {websocket}
assert {websocket} == {websocket}


def test_websocket_close_reason(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()
await websocket.close(code=1001, reason="Closing")

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"


def test_websocket_close_reason_manual(test_client_factory) -> None:
def app(scope):
async def asgi(receive, send):
websocket = WebSocket(scope, receive=receive, send=send)
await websocket.accept()

websocket_close = WebSocketClose(code=1001, reason="Closing")
aminalaee marked this conversation as resolved.
Show resolved Hide resolved
await websocket_close(scope, receive, send)

return asgi

client = test_client_factory(app)
with client.websocket_connect("/") as websocket:
with pytest.raises(WebSocketDisconnect) as exc:
websocket.receive_text()
assert exc.value.code == status.WS_1001_GOING_AWAY
assert exc.value.reason == "Closing"