diff --git a/Cargo.toml b/Cargo.toml index 1b6416f07d7..c7b25ff82b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,9 +31,6 @@ unindent = { version = "0.2.1", optional = true } # support crate for multiple-pymethods feature inventory = { version = "0.3.0", optional = true } -# coroutine implementation -futures-util = "0.3" - # crate integrations that can be added using the eponymous features anyhow = { version = "1.0", optional = true } chrono = { version = "0.4.25", default-features = false, optional = true } diff --git a/src/coroutine.rs b/src/coroutine.rs index c4c7bbf29cd..6380b4e0a1f 100644 --- a/src/coroutine.rs +++ b/src/coroutine.rs @@ -1,19 +1,17 @@ //! Python coroutine implementation, used notably when wrapping `async fn` //! with `#[pyfunction]`/`#[pymethods]`. use std::{ - any::Any, future::Future, panic, pin::Pin, sync::Arc, - task::{Context, Poll}, + task::{Context, Poll, Waker}, }; -use futures_util::FutureExt; use pyo3_macros::{pyclass, pymethods}; use crate::{ - coroutine::waker::AsyncioWaker, + coroutine::{cancel::ThrowCallback, waker::AsyncioWaker}, exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration}, panic::PanicException, pyclass::IterNextOutput, @@ -24,20 +22,17 @@ use crate::{ pub(crate) mod cancel; mod waker; -use crate::coroutine::cancel::ThrowCallback; pub use cancel::CancelHandle; const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine"; -type FutureOutput = Result, Box>; - /// Python coroutine wrapping a [`Future`]. #[pyclass(crate = "crate")] pub struct Coroutine { name: Option>, qualname_prefix: Option<&'static str>, throw_callback: Option, - future: Option + Send>>>, + future: Option> + Send>>>, waker: Option>, } @@ -68,7 +63,7 @@ impl Coroutine { name, qualname_prefix, throw_callback, - future: Some(Box::pin(panic::AssertUnwindSafe(wrap).catch_unwind())), + future: Some(Box::pin(wrap)), waker: None, } } @@ -98,14 +93,20 @@ impl Coroutine { } else { self.waker = Some(Arc::new(AsyncioWaker::new())); } - let waker = futures_util::task::waker(self.waker.clone().unwrap()); + let waker = Waker::from(self.waker.clone().unwrap()); // poll the Rust future and forward its results if ready - if let Poll::Ready(res) = future_rs.as_mut().poll(&mut Context::from_waker(&waker)) { - self.close(); - return match res { - Ok(res) => Ok(IterNextOutput::Return(res?)), - Err(err) => Err(PanicException::from_panic_payload(err)), - }; + // polling is UnwindSafe because the future is dropped in case of panic + let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker)); + match panic::catch_unwind(panic::AssertUnwindSafe(poll)) { + Ok(Poll::Ready(res)) => { + self.close(); + return Ok(IterNextOutput::Return(res?)); + } + Err(err) => { + self.close(); + return Err(PanicException::from_panic_payload(err)); + } + _ => {} } // otherwise, initialize the waker `asyncio.Future` if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? { @@ -113,7 +114,7 @@ impl Coroutine { // and will yield itself if its result has not been set in polling above if let Some(future) = PyIterator::from_object(future).unwrap().next() { // future has not been leaked into Python for now, and Rust code can only call - // `set_result(None)` in `ArcWake` implementation, so it's safe to unwrap + // `set_result(None)` in `Wake` implementation, so it's safe to unwrap return Ok(IterNextOutput::Yield(future.unwrap().into())); } } diff --git a/src/coroutine/waker.rs b/src/coroutine/waker.rs index 7ed4103fbb7..8a1166ce3fb 100644 --- a/src/coroutine/waker.rs +++ b/src/coroutine/waker.rs @@ -1,11 +1,11 @@ use crate::sync::GILOnceCell; use crate::types::PyCFunction; use crate::{intern, wrap_pyfunction, Py, PyAny, PyObject, PyResult, Python}; -use futures_util::task::ArcWake; use pyo3_macros::pyfunction; use std::sync::Arc; +use std::task::Wake; -/// Lazy `asyncio.Future` wrapper, implementing [`ArcWake`] by calling `Future.set_result`. +/// Lazy `asyncio.Future` wrapper, implementing [`Wake`] by calling `Future.set_result`. /// /// asyncio future is let uninitialized until [`initialize_future`][1] is called. /// If [`wake`][2] is called before future initialization (during Rust future polling), @@ -31,10 +31,14 @@ impl AsyncioWaker { } } -impl ArcWake for AsyncioWaker { - fn wake_by_ref(arc_self: &Arc) { +impl Wake for AsyncioWaker { + fn wake(self: Arc) { + self.wake_by_ref() + } + + fn wake_by_ref(self: &Arc) { Python::with_gil(|gil| { - if let Some(loop_and_future) = arc_self.0.get_or_init(gil, || None) { + if let Some(loop_and_future) = self.0.get_or_init(gil, || None) { loop_and_future .set_result(gil) .expect("unexpected error in coroutine waker"); diff --git a/tests/test_coroutine.rs b/tests/test_coroutine.rs index 5d4f04d63c3..cf975423c25 100644 --- a/tests/test_coroutine.rs +++ b/tests/test_coroutine.rs @@ -1,12 +1,14 @@ #![cfg(feature = "macros")] #![cfg(not(target_arch = "wasm32"))] -use std::ops::Deref; -use std::{task::Poll, thread, time::Duration}; +use std::{ops::Deref, task::Poll, thread, time::Duration}; use futures::{channel::oneshot, future::poll_fn, FutureExt}; -use pyo3::coroutine::CancelHandle; -use pyo3::types::{IntoPyDict, PyType}; -use pyo3::{prelude::*, py_run}; +use pyo3::{ + coroutine::CancelHandle, + prelude::*, + py_run, + types::{IntoPyDict, PyType}, +}; #[path = "../src/tests/common.rs"] mod common; @@ -119,7 +121,7 @@ fn cancelled_coroutine() { let test = r#" import asyncio async def main(): - task = asyncio.create_task(sleep(1)) + task = asyncio.create_task(sleep(999)) await asyncio.sleep(0) task.cancel() await task @@ -155,7 +157,7 @@ fn coroutine_cancel_handle() { let test = r#" import asyncio; async def main(): - task = asyncio.create_task(cancellable_sleep(1)) + task = asyncio.create_task(cancellable_sleep(999)) await asyncio.sleep(0) task.cancel() return await task @@ -203,3 +205,32 @@ fn coroutine_is_cancelled() { .unwrap(); }) } + +#[test] +fn coroutine_panic() { + #[pyfunction] + async fn panic() { + panic!("test panic"); + } + Python::with_gil(|gil| { + let panic = wrap_pyfunction!(panic, gil).unwrap(); + let test = r#" + import asyncio + coro = panic() + try: + asyncio.run(coro) + except BaseException as err: + assert type(err).__name__ == "PanicException" + assert str(err) == "test panic" + else: + assert False + try: + coro.send(None) + except RuntimeError as err: + assert str(err) == "cannot reuse already awaited coroutine" + else: + assert False; + "#; + py_run!(gil, panic, &handle_windows(test)); + }) +}