Skip to content

Commit

Permalink
refactor: use more appropriate Exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Mar 4, 2023
1 parent fc73fef commit 9cdc9c5
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 95 deletions.
3 changes: 0 additions & 3 deletions include/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#pragma once

#include <absl/strings/str_format.h>
Expand Down
3 changes: 0 additions & 3 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#pragma once

#include <absl/container/flat_hash_map.h>
Expand Down
3 changes: 0 additions & 3 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#pragma once

#include <absl/container/inlined_vector.h>
Expand Down
33 changes: 15 additions & 18 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#pragma once

#include <Python.h>
Expand Down Expand Up @@ -254,50 +251,50 @@ inline void SET_ITEM<py::list>(const py::handle& container,
template <typename PyType>
inline void AssertExact(const py::handle& object) {
if (!py::isinstance<PyType>(object)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Expected an instance of %s, got %s.", typeid(PyType).name(), py::repr(object)));
}
}
template <>
inline void AssertExact<py::list>(const py::handle& object) {
if (!PyList_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("Expected an instance of list, got %s.", py::repr(object)));
}
}
template <>
inline void AssertExact<py::tuple>(const py::handle& object) {
if (!PyTuple_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("Expected an instance of tuple, got %s.", py::repr(object)));
}
}
template <>
inline void AssertExact<py::dict>(const py::handle& object) {
if (!PyDict_CheckExact(object.ptr())) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("Expected an instance of dict, got %s.", py::repr(object)));
}
}

inline void AssertExactOrderedDict(const py::handle& object) {
if (!object.get_type().is(PyOrderedDictTypeObject)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.OrderedDict, got %s.", py::repr(object)));
}
}

inline void AssertExactDefaultDict(const py::handle& object) {
if (!object.get_type().is(PyDefaultDictTypeObject)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.defaultdict, got %s.", py::repr(object)));
}
}

inline void AssertExactDeque(const py::handle& object) {
if (!object.get_type().is(PyDequeTypeObject)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
"Expected an instance of collections.deque, got %s.", py::repr(object)));
throw py::value_error(absl::StrFormat("Expected an instance of collections.deque, got %s.",
py::repr(object)));
}
}

Expand Down Expand Up @@ -334,7 +331,7 @@ inline bool IsNamedTuple(const py::handle& object) {
}
inline void AssertExactNamedTuple(const py::handle& object) {
if (!IsNamedTupleInstance(object)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Expected an instance of collections.namedtuple, got %s.", py::repr(object)));
}
}
Expand All @@ -343,13 +340,13 @@ inline py::tuple NamedTupleGetFields(const py::handle& object) {
if (PyType_Check(object.ptr())) {
type = object;
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
"Expected a collections.namedtuple type, got %s.", py::repr(object)));
throw py::type_error(absl::StrFormat("Expected a collections.namedtuple type, got %s.",
py::repr(object)));
}
} else {
type = object.get_type();
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::type_error(absl::StrFormat(
"Expected an instance of collections.namedtuple type, got %s.", py::repr(object)));
}
}
Expand Down Expand Up @@ -391,7 +388,7 @@ inline bool IsStructSequence(const py::handle& object) {
}
inline void AssertExactStructSequence(const py::handle& object) {
if (!IsStructSequenceInstance(object)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Expected an instance of StructSequence type, got %s.", py::repr(object)));
}
}
Expand All @@ -400,13 +397,13 @@ inline py::tuple StructSequenceGetFields(const py::handle& object) {
if (PyType_Check(object.ptr())) {
type = object;
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw std::invalid_argument(
throw py::type_error(
absl::StrFormat("Expected a StructSequence type, got %s.", py::repr(object)));
}
} else {
type = object.get_type();
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::type_error(absl::StrFormat(
"Expected an instance of StructSequence type, got %s.", py::repr(object)));
}
}
Expand Down
4 changes: 2 additions & 2 deletions optree/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,11 +281,11 @@ def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]:
if isinstance(obj, type):
cls = obj
if not is_namedtuple_class(cls):
raise ValueError(f'Expected a collections.namedtuple type, got {cls!r}.')
raise TypeError(f'Expected a collections.namedtuple type, got {cls!r}.')
else:
cls = type(obj)
if not is_namedtuple_class(cls):
raise ValueError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
raise TypeError(f'Expected an instance of collections.namedtuple type, got {obj!r}.')
return cls._fields # type: ignore[attr-defined]


Expand Down
5 changes: 3 additions & 2 deletions optree/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MetaOPT Team. All Rights Reserved.
# Copyright 2022-2023 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -23,7 +23,8 @@ def safe_zip(*args: Sequence[Any]) -> list[tuple[Any, ...]]:
"""Strict zip that requires all arguments to be the same length."""
n = len(args[0])
for arg in args[1:]:
assert len(arg) == n, f'length mismatch: {list(map(len, args))}'
if len(arg) != n:
raise ValueError(f'length mismatch: {list(map(len, args))}')
return list(zip(*args))


Expand Down
3 changes: 0 additions & 3 deletions src/optree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

Expand Down
9 changes: 3 additions & 6 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#include "include/registry.h"

namespace optree {
Expand Down Expand Up @@ -71,7 +68,7 @@ template <bool NoneIsLeaf>
registration->from_iterable = py::reinterpret_borrow<py::function>(from_iterable);
if (registry_namespace.empty()) [[unlikely]] {
if (!registry->m_registrations.emplace(cls, std::move(registration)).second) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"PyTree type %s is already registered in the global namespace.", py::repr(cls)));
}
if (IsNamedTupleClass(cls)) [[unlikely]] {
Expand All @@ -95,13 +92,13 @@ template <bool NoneIsLeaf>
}
} else [[likely]] {
if (registry->m_registrations.find(cls) != registry->m_registrations.end()) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"PyTree type %s is already registered in the global namespace.", py::repr(cls)));
}
if (!registry->m_named_registrations
.emplace(std::make_pair(registry_namespace, cls), std::move(registration))
.second) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("PyTree type %s is already registered in namespace %s.",
py::repr(cls),
py::repr(py::str(registry_namespace))));
Expand Down
33 changes: 15 additions & 18 deletions src/treespec/flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#include "include/treespec.h"

namespace optree {
Expand Down Expand Up @@ -409,7 +406,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
ssize_t leaf = num_leaves - 1;
while (!agenda.empty()) {
if (it == m_traversal.rend()) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("Tree structures did not match; expected: %s, got: %s.",
ToString(),
py::repr(full_tree)));
Expand All @@ -433,7 +430,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
AssertExact<py::tuple>(object);
auto tuple = py::reinterpret_borrow<py::tuple>(object);
if (GET_SIZE<py::tuple>(tuple) != node.arity) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("tuple arity mismatch; expected: %ld, got: %ld; tuple: %s.",
node.arity,
GET_SIZE<py::tuple>(tuple),
Expand All @@ -449,7 +446,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
AssertExact<py::list>(object);
auto list = py::reinterpret_borrow<py::list>(object);
if (GET_SIZE<py::list>(list) != node.arity) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("list arity mismatch; expected: %ld, got: %ld; list: %s.",
node.arity,
GET_SIZE<py::list>(list),
Expand Down Expand Up @@ -480,7 +477,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
py::object expected_default_factory =
GET_ITEM_BORROW<py::tuple>(node.node_data, 0);
if (default_factory.not_equal(expected_default_factory)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"defaultdict factory mismatch; "
"expected factory: %s, got factory: %s; defaultdict: %s.",
py::repr(expected_default_factory),
Expand All @@ -492,7 +489,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] {
py::list keys = DictKeys(dict);
if (keys.not_equal(expected_keys)) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("OrderedDict key mismatch; "
"expected key(s): %s, got key(s): %s; OrderedDict: %s.",
py::repr(expected_keys),
Expand All @@ -512,7 +509,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
key_difference +=
absl::StrFormat(", extra key(s): %s", py::repr(extra_keys));
}
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("%s key mismatch; "
"expected key(s): %s, got key(s): %s%s; %s: %s.",
cls_name,
Expand All @@ -532,14 +529,14 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
AssertExactNamedTuple(object);
auto tuple = py::reinterpret_borrow<py::tuple>(object);
if (GET_SIZE<py::tuple>(tuple) != node.arity) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"namedtuple arity mismatch; expected: %ld, got: %ld; tuple: %s.",
node.arity,
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"namedtuple type mismatch; expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
py::repr(object.get_type()),
Expand All @@ -555,7 +552,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
AssertExactDeque(object);
auto list = py::cast<py::list>(object);
if (GET_SIZE<py::list>(list) != node.arity) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("deque arity mismatch; expected: %ld, got: %ld; deque: %s.",
node.arity,
GET_SIZE<py::list>(list),
Expand All @@ -571,14 +568,14 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
AssertExactStructSequence(object);
auto tuple = py::reinterpret_borrow<py::tuple>(object);
if (GET_SIZE<py::tuple>(tuple) != node.arity) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"StructSequence arity mismatch; expected: %ld, got: %ld; tuple: %s.",
node.arity,
GET_SIZE<py::tuple>(tuple),
py::repr(object)));
}
if (object.get_type().not_equal(node.node_data)) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("StructSequence type mismatch; "
"expected type: %s, got type: %s; tuple: %s.",
py::repr(node.node_data),
Expand All @@ -601,7 +598,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
PyTreeTypeRegistry::Lookup<NONE_IS_NODE>(object.get_type(), m_namespace);
}
if (registration != node.custom) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Custom node type mismatch; expected type: %s, got type: %s; value: %s.",
py::repr(node.custom->type),
py::repr(object.get_type()),
Expand All @@ -617,7 +614,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
num_out));
}
if (node.node_data.not_equal(GET_ITEM_BORROW<py::tuple>(out, 1))) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Mismatch custom node data; expected: %s, got: %s; value: %s.",
py::repr(node.node_data),
py::repr(GET_ITEM_BORROW<py::tuple>(out, 1)),
Expand All @@ -630,7 +627,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
agenda.emplace_back(py::reinterpret_borrow<py::object>(child));
}
if (arity != node.arity) [[unlikely]] {
throw std::invalid_argument(absl::StrFormat(
throw py::value_error(absl::StrFormat(
"Custom type arity mismatch; expected: %ld, got: %ld; value: %s.",
node.arity,
arity,
Expand All @@ -644,7 +641,7 @@ py::list PyTreeSpec::FlattenUpToImpl(const py::handle& full_tree) const {
}
}
if (it != m_traversal.rend() || leaf != -1) [[unlikely]] {
throw std::invalid_argument(
throw py::value_error(
absl::StrFormat("Tree structures did not match; expected: %s, got: %s.",
ToString(),
py::repr(full_tree)));
Expand Down
7 changes: 2 additions & 5 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ limitations under the License.
================================================================================
*/

// Caution: this code uses exceptions. The exception use is local to the binding
// code and the idiomatic way to emit Python exceptions.

#include "include/treespec.h"

namespace optree {
Expand All @@ -32,7 +29,7 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
switch (node.kind) {
case PyTreeKind::Leaf: {
if (it == leaves.end()) [[unlikely]] {
throw std::invalid_argument("Too few leaves for PyTreeSpec.");
throw py::value_error("Too few leaves for PyTreeSpec.");
}

auto leaf = py::reinterpret_borrow<py::object>(*it);
Expand Down Expand Up @@ -69,7 +66,7 @@ py::object PyTreeSpec::Walk(const py::function& f_node,
}
}
if (it != leaves.end()) [[unlikely]] {
throw std::invalid_argument("Too many leaves for PyTreeSpec.");
throw py::value_error("Too many leaves for PyTreeSpec.");
}

EXPECT_EQ(agenda.size(), 1, "PyTreeSpec traversal did not yield a singleton.");
Expand Down
Loading

0 comments on commit 9cdc9c5

Please sign in to comment.