Skip to content

Commit

Permalink
Merge pull request #1591 from davidhewitt/inherit-exceptions
Browse files Browse the repository at this point in the history
pyclass: support extending Exception types
  • Loading branch information
davidhewitt authored May 4, 2021
2 parents 05db24c + d81abe8 commit 350e7b2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Add FFI definitions from `cpython/import.h`.[#1475](/~https://github.com/PyO3/pyo3/pull/1475)
- Add tuple and unit struct support for `#[pyclass]` macro. [#1504](/~https://github.com/PyO3/pyo3/pull/1504)
- Add FFI definition `PyDateTime_TimeZone_UTC`. [#1572](/~https://github.com/PyO3/pyo3/pull/1572)
- Add support for `#[pyclass(extends=Exception)]`. [#1591](/~https://github.com/PyO3/pyo3/pull/1591)

### Changed
- Change `PyTimeAcces::get_fold()` to return a `bool` instead of a `u8`. [#1397](/~https://github.com/PyO3/pyo3/pull/1397)
Expand Down
46 changes: 45 additions & 1 deletion tests/test_inheritance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ except Exception as e:
#[cfg(not(Py_LIMITED_API))]
mod inheriting_native_type {
use super::*;
use pyo3::types::{PyDict, PySet};
use pyo3::exceptions::PyException;
use pyo3::types::{IntoPyDict, PyDict, PySet};

#[pyclass(extends=PySet)]
#[derive(Debug)]
Expand Down Expand Up @@ -208,6 +209,49 @@ mod inheriting_native_type {
r#"dict_sub[0] = 1; assert dict_sub[0] == 1; assert dict_sub._name == "Hello :)""#
);
}

#[pyclass(extends=PyException)]
struct CustomException {
#[pyo3(get)]
context: &'static str,
}

#[pymethods]
impl CustomException {
#[new]
fn new() -> Self {
CustomException {
context: "Hello :)",
}
}
}

#[test]
fn custom_exception() {
Python::with_gil(|py| {
let cls = py.get_type::<CustomException>();
let dict = [("cls", cls)].into_py_dict(py);
let res = py.run(
"e = cls('hello'); assert str(e) == 'hello'; assert e.context == 'Hello :)'; raise e",
None,
Some(dict)
);
let err = res.unwrap_err();
assert!(err.matches(py, cls), "{}", err);

// catching the exception in Python also works:
py_run!(
py,
cls,
r#"
try:
raise cls("foo")
except cls:
pass
"#
)
})
}
}

#[pyclass(subclass)]
Expand Down

0 comments on commit 350e7b2

Please sign in to comment.