Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove dependence on abseil-cpp #85

Merged
merged 11 commits into from
Sep 26, 2023
1 change: 0 additions & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# NOTE: there must be no spaces before the '-', so put the comma last.
InheritParentConfig: true
Checks: '
abseil-*,
bugprone-*,
-bugprone-easily-swappable-parameters,
clang-analyzer-*,
Expand Down
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Decrease the `MAX_RECURSION_DEPTH` to 2000 on Windows by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#85](/~https://github.com/metaopt/optree/pull/85).
- Bump `abseil-cpp` version to 20230802.1 by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#80](/~https://github.com/metaopt/optree/pull/80).

### Fixed
Expand All @@ -25,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Removed

-
- Remove dependence on `abseil-cpp` by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#85](/~https://github.com/metaopt/optree/pull/85).

------

Expand Down
25 changes: 0 additions & 25 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,9 @@ set(THIRD_PARTY_DIR "${CMAKE_SOURCE_DIR}/third-party")
if(NOT DEFINED PYBIND11_VERSION AND NOT "$ENV{PYBIND11_VERSION}" STREQUAL "")
set(PYBIND11_VERSION "$ENV{PYBIND11_VERSION}")
endif()
if(NOT DEFINED ABSEIL_CPP_VERSION AND NOT "$ENV{ABSEIL_CPP_VERSION}" STREQUAL "")
set(ABSEIL_CPP_VERSION "$ENV{ABSEIL_CPP_VERSION}")
endif()
if(NOT PYBIND11_VERSION)
set(PYBIND11_VERSION v2.11.1)
endif()
if(NOT ABSEIL_CPP_VERSION)
set(ABSEIL_CPP_VERSION 20230802.1)
endif()

if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release)
Expand Down Expand Up @@ -189,24 +183,5 @@ else()
find_package(pybind11 CONFIG PATHS "${PYBIND11_CMAKE_DIR}")
endif()

# Include abseil-cpp
set(ABSL_PROPAGATE_CXX_STD ON)
set(ABSL_BUILD_TESTING OFF)
FetchContent_Declare(
abseilcpp
GIT_REPOSITORY /~https://github.com/abseil/abseil-cpp.git
GIT_TAG "${ABSEIL_CPP_VERSION}"
GIT_SHALLOW TRUE
SOURCE_DIR "${THIRD_PARTY_DIR}/abseil-cpp"
BINARY_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/build"
STAMP_DIR "${THIRD_PARTY_DIR}/.cmake/abseil-cpp/stamp"
)
FetchContent_GetProperties(abseilcpp)

if(NOT abseilcpp_POPULATED)
message(STATUS "Populating Git repository abseil-cpp@${ABSEIL_CPP_VERSION} to third-party/abseil-cpp...")
FetchContent_MakeAvailable(abseilcpp)
endif()

include_directories("${CMAKE_SOURCE_DIR}")
add_subdirectory(src)
16 changes: 8 additions & 8 deletions include/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ limitations under the License.

#pragma once

#include <absl/strings/str_format.h>

#include <sstream>
#include <stdexcept>
#include <string>

Expand All @@ -39,12 +38,13 @@ class InternalError : public std::logic_error {
public:
explicit InternalError(const std::string& msg) : std::logic_error(msg) {}
InternalError(const std::string& msg, const std::string& file, const size_t& lineno)
: InternalError(absl::StrFormat(
"%s (at file %s:%lu)\n\n%s",
msg,
file,
lineno,
"Please file a bug report at /~https://github.com/metaopt/optree/issues.")) {}
: InternalError([&msg, &file, &lineno]() {
std::stringstream ss;
ss << msg << " (at file " << file << ":" << lineno << ")";
ss << std::endl << std::endl;
ss << "Please file a bug report at /~https://github.com/metaopt/optree/issues.";
return ss.str();
}()) {}
};

} // namespace optree
Expand Down
20 changes: 12 additions & 8 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@ limitations under the License.

#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/hash/hash.h>
#include <pybind11/pybind11.h>

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

#include "include/exceptions.h"
Expand Down Expand Up @@ -109,6 +108,8 @@ class PyTreeTypeRegistry {
using is_transparent = void;
bool operator()(const py::object &a, const py::object &b) const;
bool operator()(const py::object &a, const py::handle &b) const;
bool operator()(const py::handle &a, const py::object &b) const;
bool operator()(const py::handle &a, const py::handle &b) const;
};

class NamedTypeHash {
Expand All @@ -124,14 +125,17 @@ class PyTreeTypeRegistry {
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::object> &a,
const std::pair<std::string, py::handle> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::object> &b) const;
bool operator()(const std::pair<std::string, py::handle> &a,
const std::pair<std::string, py::handle> &b) const;
};

absl::flat_hash_map<py::object, std::unique_ptr<Registration>, TypeHash, TypeEq>
m_registrations;
absl::flat_hash_map<std::pair<std::string, py::object>,
std::unique_ptr<Registration>,
NamedTypeHash,
NamedTypeEq>
std::unordered_map<py::object, std::unique_ptr<Registration>, TypeHash, TypeEq> m_registrations;
std::unordered_map<std::pair<std::string, py::object>,
std::unique_ptr<Registration>,
NamedTypeHash,
NamedTypeEq>
m_named_registrations;
};

Expand Down
113 changes: 28 additions & 85 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@ limitations under the License.

#pragma once

#include <absl/container/flat_hash_set.h>
#include <absl/container/inlined_vector.h>
#include <absl/hash/hash.h>
#include <pybind11/pybind11.h>

#include <memory>
Expand All @@ -28,6 +25,7 @@ limitations under the License.
#include <string>
#include <thread> // NOLINT[build/c++11]
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>

Expand All @@ -39,7 +37,7 @@ namespace optree {

// The maximum depth of a pytree.
#ifdef _WIN32
constexpr ssize_t MAX_RECURSION_DEPTH = 2500;
constexpr ssize_t MAX_RECURSION_DEPTH = 2000;
#else
constexpr ssize_t MAX_RECURSION_DEPTH = 5000;
#endif
Expand Down Expand Up @@ -159,77 +157,7 @@ class PyTreeSpec {
inline bool operator>=(const PyTreeSpec &other) const { return IsSuffix(other, false); }

// Return the hash value of the PyTreeSpec.
template <typename H>
friend H AbslHashValue(H h, const Node &n) {
ssize_t data_hash = 0;
switch (n.kind) {
case PyTreeKind::Custom:
// We don't hash node_data custom node types since they may not hashable.
break;
case PyTreeKind::Leaf:
case PyTreeKind::None:
case PyTreeKind::Tuple:
case PyTreeKind::List:
case PyTreeKind::NamedTuple:
case PyTreeKind::Deque:
case PyTreeKind::StructSequence:
data_hash = py::hash(n.node_data ? n.node_data : py::none());
break;
case PyTreeKind::Dict:
case PyTreeKind::OrderedDict:
case PyTreeKind::DefaultDict: {
py::list keys;
if (n.kind == PyTreeKind::DefaultDict) [[unlikely]] {
EXPECT_EQ(
GET_SIZE<py::tuple>(n.node_data), 2, "Number of auxiliary data mismatch.");
py::object default_factory = GET_ITEM_BORROW<py::tuple>(n.node_data, 0);
keys = py::reinterpret_borrow<py::list>(
GET_ITEM_BORROW<py::tuple>(n.node_data, 1));
EXPECT_EQ(GET_SIZE<py::list>(keys),
n.arity,
"Number of keys and entries does not match.");
data_hash = py::hash(default_factory);
} else [[likely]] {
EXPECT_EQ(GET_SIZE<py::list>(n.node_data),
n.arity,
"Number of keys and entries does not match.");
keys = py::reinterpret_borrow<py::list>(n.node_data);
}
for (const py::handle &&key : keys) {
data_hash = py::ssize_t_cast(absl::HashOf(data_hash, py::hash(key)));
}
break;
}
default:
INTERNAL_ERROR();
}

return H::combine(
std::move(h), n.kind, n.arity, n.custom, n.num_leaves, n.num_nodes, data_hash);
}

template <typename H>
friend H AbslHashValueImpl(H h, const PyTreeSpec &t) {
return H::combine(std::move(h), t.m_traversal, t.m_none_is_leaf, t.m_namespace);
}

template <typename H>
friend H AbslHashValue(H h, const PyTreeSpec &t) {
std::pair<const PyTreeSpec *, std::thread::id> indent{&t, std::this_thread::get_id()};
if (sm_hash_running.contains(indent)) {
return h;
}

sm_hash_running.insert(indent);
try {
H hash = AbslHashValueImpl(std::move(h), t);
sm_hash_running.erase(indent);
return hash;
} catch (...) {
sm_hash_running.erase(indent);
std::rethrow_exception(std::current_exception());
}
}
[[nodiscard]] ssize_t HashValue() const;

// Return a string representation of the PyTreeSpec.
[[nodiscard]] std::string ToString() const;
Expand Down Expand Up @@ -280,24 +208,18 @@ class PyTreeSpec {

// Nodes, in a post-order traversal. We use an ordered traversal to minimize allocations, and
// post-order corresponds to the order we need to rebuild the tree structure.
absl::InlinedVector<Node, 1> m_traversal;
std::vector<Node> m_traversal;

// Whether to treat `None` as a leaf. If false, `None` is a non-leaf node with arity 0.
bool m_none_is_leaf = false;

// The registry namespace used to resolve the custom pytree node types.
std::string m_namespace;

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
inline static absl::flat_hash_set<std::pair<const PyTreeSpec *, std::thread::id>>
sm_repr_running{};

// A set of (treespec, thread_id) pairs that are currently being hashed.
inline static absl::flat_hash_set<std::pair<const PyTreeSpec *, std::thread::id>>
sm_hash_running{};

// Helper that manufactures an instance of a node given its children.
static py::object MakeNode(const Node &node, const absl::Span<py::object> &children);
static py::object MakeNode(const Node &node,
const py::object *children,
const size_t &num_children);

// Compute the node kind of a given Python object.
template <bool NoneIsLeaf>
Expand Down Expand Up @@ -335,9 +257,30 @@ class PyTreeSpec {
const ssize_t &pos,
const ssize_t &depth) const;

// Get the hash value of the node.
static void HashCombineNode(ssize_t &seed, const Node &node); // NOLINT[runtime/references]

[[nodiscard]] ssize_t HashValueImpl() const;

[[nodiscard]] std::string ToStringImpl() const;

static std::unique_ptr<PyTreeSpec> FromPicklableImpl(const py::object &picklable);

class ThreadIndentTypeHash {
public:
using is_transparent = void;
size_t operator()(const std::pair<const PyTreeSpec *, std::thread::id> &p) const;
};

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
inline static std::unordered_set<std::pair<const PyTreeSpec *, std::thread::id>,
ThreadIndentTypeHash>
sm_repr_running{};

// A set of (treespec, thread_id) pairs that are currently being hashed.
inline static std::unordered_set<std::pair<const PyTreeSpec *, std::thread::id>,
ThreadIndentTypeHash>
sm_hash_running{};
};

} // namespace optree
Loading