Skip to content

Commit

Permalink
feat(ws): return error when sending/receiving after websocket disconn…
Browse files Browse the repository at this point in the history
…ect (#99)
  • Loading branch information
0x676e67 authored Feb 28, 2025
1 parent 92a1fc9 commit d4cf2f0
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 93 deletions.
4 changes: 2 additions & 2 deletions rnet.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -796,7 +796,7 @@ class ClientParams:
base_url: typing.Optional[builtins.str]
user_agent: typing.Optional[builtins.str]
default_headers: typing.Optional[typing.Dict[str, bytes]]
headers_order: typing.Optional[builtins.list[builtins.str]]
headers_order: typing.Optional[typing.List[str]]
referer: typing.Optional[builtins.bool]
allow_redirects: typing.Optional[builtins.bool]
max_redirects: typing.Optional[builtins.int]
Expand Down Expand Up @@ -1378,7 +1378,7 @@ class UpdateClientParams:
impersonate_skip_http2: typing.Optional[builtins.bool]
impersonate_skip_headers: typing.Optional[builtins.bool]
headers: typing.Optional[typing.Dict[str, bytes]]
headers_order: typing.Optional[builtins.list[builtins.str]]
headers_order: typing.Optional[typing.List[str]]
proxies: typing.Optional[builtins.list[Proxy]]
local_address: typing.Optional[typing.Optional[typing.Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address]]]
interface: typing.Optional[builtins.str]
Expand Down
102 changes: 58 additions & 44 deletions src/async_impl/ws/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mod message;

use crate::{
error::{py_stop_async_iteration_error, wrap_rquest_error},
error::{py_stop_async_iteration_error, websocket_disconnect_error, wrap_rquest_error},
typing::{HeaderMap, SocketAddr, StatusCode, Version},
};
use futures_util::{
Expand Down Expand Up @@ -56,17 +56,65 @@ impl WebSocket {
}

impl WebSocket {
/// Returns the sender of the WebSocket.
#[inline(always)]
pub fn sender(&self) -> Sender {
self.sender.clone()
}

/// Returns the receiver of the WebSocket.
#[inline(always)]
pub fn receiver(&self) -> Receiver {
self.receiver.clone()
}

pub async fn _recv(receiver: Receiver) -> PyResult<Option<Message>> {
let mut lock = receiver.lock().await;
lock.as_mut()
.ok_or_else(websocket_disconnect_error)?
.try_next()
.await
.map(|val| val.map(Message))
.map_err(wrap_rquest_error)
}

pub async fn _send(sender: Sender, message: Message) -> PyResult<()> {
let mut lock = sender.lock().await;
lock.as_mut()
.ok_or_else(websocket_disconnect_error)?
.send(message.0)
.await
.map_err(wrap_rquest_error)
}

pub async fn _close(
receiver: Receiver,
sender: Sender,
code: Option<u16>,
reason: Option<String>,
) -> PyResult<()> {
let mut lock = receiver.lock().await;
let receiver = lock.take();
drop(lock);

let mut lock = sender.lock().await;
let sender = lock.take();
drop(lock);

let (receiver, mut sender) = receiver
.zip(sender)
.ok_or_else(websocket_disconnect_error)?;
drop(receiver);

if let Some(code) = code {
sender
.send(rquest::Message::Close {
code: rquest::CloseCode::from(code),
reason,
})
.await
.map_err(wrap_rquest_error)?;
}
sender.close().await.map_err(wrap_rquest_error)
}
}

#[gen_stub_pymethods]
Expand Down Expand Up @@ -143,6 +191,7 @@ impl WebSocket {
/// # Returns
///
/// An optional string representing the WebSocket protocol.
#[inline(always)]
pub fn protocol(&self) -> Option<&str> {
self.protocol.as_deref()
}
Expand All @@ -156,17 +205,9 @@ impl WebSocket {
/// # Returns
///
/// A `PyResult` containing a `Bound` object with the received message, or `None` if no message is received.
#[inline(always)]
pub fn recv<'rt>(&self, py: Python<'rt>) -> PyResult<Bound<'rt, PyAny>> {
let websocket = self.receiver.clone();
future_into_py(py, async move {
let mut lock = websocket.lock().await;
if let Some(recv) = lock.as_mut() {
if let Ok(Some(val)) = recv.try_next().await {
return Ok(Some(Message(val)));
}
}
Ok(None)
})
future_into_py(py, Self::_recv(self.receiver.clone()))
}

/// Sends a message to the WebSocket.
Expand All @@ -180,15 +221,9 @@ impl WebSocket {
///
/// A `PyResult` containing a `Bound` object.
#[pyo3(signature = (message))]
#[inline(always)]
pub fn send<'rt>(&self, py: Python<'rt>, message: Message) -> PyResult<Bound<'rt, PyAny>> {
let sender = self.sender.clone();
future_into_py(py, async move {
let mut lock = sender.lock().await;
if let Some(send) = lock.as_mut() {
return send.send(message.0).await.map_err(wrap_rquest_error);
}
Ok(())
})
future_into_py(py, Self::_send(self.sender.clone(), message))
}

/// Closes the WebSocket connection.
Expand All @@ -203,6 +238,7 @@ impl WebSocket {
///
/// A `PyResult` containing a `Bound` object.
#[pyo3(signature = (code=None, reason=None))]
#[inline(always)]
pub fn close<'rt>(
&self,
py: Python<'rt>,
Expand All @@ -211,29 +247,7 @@ impl WebSocket {
) -> PyResult<Bound<'rt, PyAny>> {
let sender = self.sender.clone();
let receiver = self.receiver.clone();
future_into_py(py, async move {
let mut lock = receiver.lock().await;
drop(lock.take());
drop(lock);

let mut lock = sender.lock().await;
let send = lock.take();
drop(lock);

if let Some(mut send) = send {
if let Some(code) = code {
send.send(rquest::Message::Close {
code: rquest::CloseCode::from(code),
reason,
})
.await
.map_err(wrap_rquest_error)?;
}
return send.close().await.map_err(wrap_rquest_error);
}

Ok(())
})
future_into_py(py, Self::_close(receiver, sender, code, reason))
}
}

Expand Down
63 changes: 16 additions & 47 deletions src/blocking/ws.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::ops::Deref;

use crate::{
async_impl::{self, Message},
async_impl::{self, Message, WebSocket},
error::{py_stop_iteration_error, wrap_rquest_error},
typing::{HeaderMap, SocketAddr, StatusCode, Version},
};
use futures_util::{SinkExt, TryStreamExt};
use futures_util::TryStreamExt;
use pyo3::prelude::*;
use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods};
use std::ops::Deref;

/// A blocking WebSocket response.
#[gen_stub_pyclass]
Expand Down Expand Up @@ -102,6 +101,7 @@ impl BlockingWebSocket {
/// # Returns
///
/// An optional string representing the WebSocket protocol.
#[inline(always)]
pub fn protocol(&self) -> Option<&str> {
self.0.protocol()
}
Expand All @@ -111,18 +111,10 @@ impl BlockingWebSocket {
/// # Returns
///
/// A `PyResult` containing a `Bound` object with the received message, or `None` if no message is received.
#[inline(always)]
pub fn recv(&self, py: Python) -> PyResult<Option<Message>> {
py.allow_threads(|| {
let websocket = self.receiver();
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let mut lock = websocket.lock().await;
if let Some(recv) = lock.as_mut() {
if let Ok(Some(val)) = recv.try_next().await {
return Ok(Some(Message(val)));
}
}
Ok(None)
})
pyo3_async_runtimes::tokio::get_runtime().block_on(WebSocket::_recv(self.receiver()))
})
}

Expand All @@ -136,16 +128,11 @@ impl BlockingWebSocket {
///
/// A `PyResult` containing a `Bound` object.
#[pyo3(signature = (message))]
#[inline(always)]
pub fn send(&self, py: Python, message: Message) -> PyResult<()> {
py.allow_threads(|| {
let sender = self.sender();
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let mut lock = sender.lock().await;
if let Some(send) = lock.as_mut() {
return send.send(message.0).await.map_err(wrap_rquest_error);
}
Ok(())
})
pyo3_async_runtimes::tokio::get_runtime()
.block_on(WebSocket::_send(self.sender(), message))
})
}

Expand All @@ -160,33 +147,15 @@ impl BlockingWebSocket {
///
/// A `PyResult` containing a `Bound` object.
#[pyo3(signature = (code=None, reason=None))]
#[inline(always)]
pub fn close(&self, py: Python, code: Option<u16>, reason: Option<String>) -> PyResult<()> {
py.allow_threads(|| {
let sender = self.sender();
let receiver = self.receiver();
pyo3_async_runtimes::tokio::get_runtime().block_on(async move {
let mut lock = receiver.lock().await;
drop(lock.take());
drop(lock);

let mut lock = sender.lock().await;
let send = lock.take();
drop(lock);

if let Some(mut send) = send {
if let Some(code) = code {
send.send(rquest::Message::Close {
code: rquest::CloseCode::from(code),
reason,
})
.await
.map_err(wrap_rquest_error)?;
}
return send.close().await.map_err(wrap_rquest_error);
}

Ok(())
})
pyo3_async_runtimes::tokio::get_runtime().block_on(WebSocket::_close(
self.receiver(),
self.sender(),
code,
reason,
))
})
}
}
Expand Down
4 changes: 4 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ pub fn py_stop_async_iteration_error() -> pyo3::PyErr {
PyStopAsyncIteration::new_err("The iterator is exhausted")
}

pub fn websocket_disconnect_error() -> pyo3::PyErr {
PyRuntimeError::new_err("The WebSocket has been disconnected")
}

pub fn stream_consumed_error() -> pyo3::PyErr {
BodyError::new_err("Stream is already consumed")
}
Expand Down

0 comments on commit d4cf2f0

Please sign in to comment.