Skip to content

Commit

Permalink
WIP: Use a capsule-based API with a stable ABI for global, cross-exte…
Browse files Browse the repository at this point in the history
…nsion functionality. [skip ci]
  • Loading branch information
adamreichold committed Mar 30, 2023
1 parent 90d50da commit 6d854fe
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 9 deletions.
85 changes: 85 additions & 0 deletions src/global_api.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//! TODO
use std::{ffi::CString, mem::forget};

use crate::{
conversion::PyTryInto,
exceptions::PyTypeError,
ffi,
sync::GILOnceCell,
type_object::PyTypeInfo,
types::{PyCapsule, PyDict, PyModule},
Py, PyResult, Python,
};

#[repr(C)]
pub(crate) struct GlobalApi {
version: u64,
pub(crate) create_panic_exception:
unsafe extern "C" fn(msg_ptr: *const u8, msg_len: usize) -> *mut ffi::PyObject,
}

pub(crate) fn ensure_global_api(py: Python<'_>) -> PyResult<&GlobalApi> {
let api = GLOBAL_API.0.get_or_try_init(py, || init_global_api(py))?;

// SAFETY: We inserted the capsule if it was missing
// and verified that it contains a compatible version.
Ok(unsafe { &**api })
}

struct GlobalApiPtr(GILOnceCell<*const GlobalApi>);

unsafe impl Send for GlobalApiPtr {}

unsafe impl Sync for GlobalApiPtr {}

static GLOBAL_API: GlobalApiPtr = GlobalApiPtr(GILOnceCell::new());

#[cold]
fn init_global_api(py: Python<'_>) -> PyResult<*const GlobalApi> {
let module = match PyModule::import(py, "pyo3") {
Ok(module) => module,
Err(_err) => {
let module = PyModule::new(py, "pyo3")?;

module.add(
"PanicException",
crate::panic::PanicException::type_object(py),
)?;

let sys = PyModule::import(py, "sys")?;
let modules: &PyDict = sys.getattr("modules")?.downcast()?;
modules.set_item("pyo3", module)?;

module
}
};

let capsule: &PyCapsule = match module.getattr("_GLOBAL_API") {
Ok(capsule) => PyTryInto::try_into(capsule)?,
Err(_err) => {
let api = GlobalApi {
version: 1,
create_panic_exception: crate::panic::create_panic_exception,
};

let capsule = PyCapsule::new(py, api, Some(CString::new("_GLOBAL_API").unwrap()))?;
module.setattr("_GLOBAL_API", capsule)?;
capsule
}
};

// SAFETY: All versions of the global API start with a version field.
let version = unsafe { *(capsule.pointer() as *mut u64) };
if version < 1 {
return Err(PyTypeError::new_err(format!(
"Version {} of global API is not supported by this version of PyO3",
version
)));
}

// Intentionally leak a reference to the capsule so we can safely cache a pointer into its interior.
forget(Py::<PyCapsule>::from(capsule));

Ok(capsule.pointer() as *const GlobalApi)
}
4 changes: 2 additions & 2 deletions src/impl_/trampoline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ where
let py_err = match panic_result {
Ok(Ok(value)) => return value,
Ok(Err(py_err)) => py_err,
Err(payload) => PanicException::from_panic_payload(payload),
Err(payload) => PanicException::from_panic_payload(py, payload),
};
py_err.restore(py);
R::ERR_VALUE
Expand All @@ -245,7 +245,7 @@ where
let pool = GILPool::new();
let py = pool.python();
if let Err(py_err) = panic::catch_unwind(move || body(py))
.unwrap_or_else(|payload| Err(PanicException::from_panic_payload(payload)))
.unwrap_or_else(|payload| Err(PanicException::from_panic_payload(py, payload)))
{
py_err.write_unraisable(py, py.from_borrowed_ptr_or_opt(ctx));
}
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ pub mod marker;
pub mod marshal;
#[macro_use]
pub mod sync;
mod global_api;
pub mod panic;
pub mod prelude;
pub mod pycell;
Expand Down
43 changes: 36 additions & 7 deletions src/panic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
//! Helper to convert Rust panics to Python exceptions.
use crate::conversion::{FromPyPointer, IntoPyPointer};
use crate::exceptions::PyBaseException;
use crate::PyErr;
use crate::ffi;
use crate::global_api::ensure_global_api;
use crate::{PyAny, Python};
use std::any::Any;
use std::slice;
use std::str;

pyo3_exception!(
"
Expand All @@ -20,13 +25,37 @@ impl PanicException {
///
/// Attempts to format the error in the same way panic does.
#[cold]
pub(crate) fn from_panic_payload(payload: Box<dyn Any + Send + 'static>) -> PyErr {
if let Some(string) = payload.downcast_ref::<String>() {
Self::new_err((string.clone(),))
pub(crate) fn from_panic_payload<'py>(
py: Python<'py>,
payload: Box<dyn Any + Send + 'static>,
) -> &'py PyAny {
let msg = if let Some(string) = payload.downcast_ref::<String>() {
string.clone()
} else if let Some(s) = payload.downcast_ref::<&str>() {
Self::new_err((s.to_string(),))
s.to_string()
} else {
Self::new_err(("panic from Rust code",))
}
"panic from Rust code".to_owned()
};

let api = match ensure_global_api(py) {
Ok(api) => api,
// The global API is unavailable, hence we fall back to our own `PanicException`.
Err(err) => return PanicException::new_err((msg,)).into_value(py).into_ref(py),
};

let err = (api.create_panic_exception)(msg.as_ptr(), msg.len());

PyAny::from_owned_ptr(py, err)
}
}

pub(crate) unsafe extern "C" fn create_panic_exception(
msg_ptr: *const u8,
msg_len: usize,
) -> *mut ffi::PyObject {
let msg = str::from_utf8_unchecked(slice::from_raw_parts(msg_ptr, msg_len));

let err = PanicException::new_err((msg,));

err.into_value(Python::assume_gil_acquired()).into_ptr()
}

0 comments on commit 6d854fe

Please sign in to comment.