Skip to content

Commit

Permalink
feat(ops): add accessor APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Apr 2, 2024
1 parent 50eaaa5 commit 1d03768
Show file tree
Hide file tree
Showing 15 changed files with 679 additions and 295 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add typed path APIs `tree_flatten_with_typed_path` and `PyTreeSpec.typed_paths` by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#108](/~https://github.com/metaopt/optree/pull/108).
- Add accessor APIs `tree_flatten_with_accessor` and `PyTreeSpec.accessors` by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#108](/~https://github.com/metaopt/optree/pull/108).

### Changed

Expand Down
12 changes: 6 additions & 6 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ Tree Manipulation Functions

tree_flatten
tree_flatten_with_path
tree_flatten_with_typed_path
tree_flatten_with_accessor
tree_unflatten
tree_iter
tree_leaves
tree_structure
tree_paths
tree_typed_paths
tree_accessors
tree_is_leaf
all_leaves
tree_map
Expand All @@ -53,13 +53,13 @@ Tree Manipulation Functions

.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_flatten_with_typed_path
.. autofunction:: tree_flatten_with_accessor
.. autofunction:: tree_unflatten
.. autofunction:: tree_iter
.. autofunction:: tree_leaves
.. autofunction:: tree_structure
.. autofunction:: tree_paths
.. autofunction:: tree_typed_paths
.. autofunction:: tree_accessors
.. autofunction:: tree_is_leaf
.. autofunction:: all_leaves
.. autofunction:: tree_map
Expand Down Expand Up @@ -108,7 +108,7 @@ PyTreeSpec Functions
.. autosummary::

treespec_paths
treespec_typed_paths
treespec_accessors
treespec_entries
treespec_entry
treespec_children
Expand All @@ -130,7 +130,7 @@ PyTreeSpec Functions
treespec_from_collection

.. autofunction:: treespec_paths
.. autofunction:: treespec_typed_paths
.. autofunction:: treespec_accessors
.. autofunction:: treespec_entries
.. autofunction:: treespec_entry
.. autofunction:: treespec_children
Expand Down
4 changes: 4 additions & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,7 @@ getattr
setattr
delattr
typecheck
dataclasses
subpath
accessor
accessors
4 changes: 4 additions & 0 deletions include/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class PyTreeTypeRegistry {
py::function flatten_func{};
// A function with signature: (metadata, iterable) -> object
py::function unflatten_func{};
// The Python type object for the path entry class.
py::object path_entry_type{};
};

using RegistrationPtr = std::shared_ptr<const Registration>;
Expand All @@ -82,6 +84,7 @@ class PyTreeTypeRegistry {
static void Register(const py::object &cls,
const py::function &flatten_func,
const py::function &unflatten_func,
const py::object &path_entry_type,
const std::string &registry_namespace = "");

static void Unregister(const py::object &cls, const std::string &registry_namespace = "");
Expand All @@ -104,6 +107,7 @@ class PyTreeTypeRegistry {
static void RegisterImpl(const py::object &cls,
const py::function &flatten_func,
const py::function &unflatten_func,
const py::object &path_entry_type,
const std::string &registry_namespace);

template <bool NoneIsLeaf>
Expand Down
21 changes: 12 additions & 9 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ bool AllLeavesImpl(const py::iterable &iterable,
const std::optional<py::function> &leaf_predicate,
const std::string &registry_namespace);

py::module_ GetCxxModule(const std::optional<py::module_> &module = std::nullopt);

// A PyTreeSpec describes the tree structure of a PyTree. A PyTree is a tree of Python values, where
// the interior nodes are tuples, lists, dictionaries, or user-defined containers, and the leaves
// are other objects.
Expand Down Expand Up @@ -114,8 +116,8 @@ class PyTreeSpec {
// Return paths to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::tuple> Paths() const;

// Return a list of tuples of (type, entry) pairs to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::tuple> TypedPaths() const;
// Return a list of accessors to all leaves in the PyTreeSpec.
[[nodiscard]] std::vector<py::object> Accessors() const;

// Return one-level entries of the PyTreeSpec to its children.
[[nodiscard]] py::list Entries() const;
Expand Down Expand Up @@ -208,10 +210,8 @@ class PyTreeSpec {
py::object node_data{};

// The tuple of path entries.
// This is optional, if not specified, `range(arity)` is used.
// For a sequence, contains the index of the element.
// For a mapping, contains the key of the element.
// For a Custom type, contains the path entries returned by the `flatten_func` function.
// This is optional, if not specified, `range(arity)` is used.
py::object node_entries{};

// Custom type registration. Must be null for non-custom types.
Expand Down Expand Up @@ -245,6 +245,9 @@ class PyTreeSpec {
const py::object *children,
const size_t &num_children);

// Helper that identifies the path entry class for a node.
static py::object GetPathEntryClass(const Node &node);

// Recursive helper used to implement Flatten().
bool FlattenInto(const py::handle &handle,
std::vector<py::object> &leaves, // NOLINT[runtime/references]
Expand Down Expand Up @@ -293,10 +296,10 @@ class PyTreeSpec {
const ssize_t &depth) const;

template <typename Span, typename Stack>
[[nodiscard]] ssize_t TypedPathsImpl(Span &typed_paths, // NOLINT[runtime/references]
Stack &stack, // NOLINT[runtime/references]
const ssize_t &pos,
const ssize_t &depth) const;
[[nodiscard]] ssize_t AccessorsImpl(Span &accessors, // NOLINT[runtime/references]
Stack &stack, // NOLINT[runtime/references]
const ssize_t &pos,
const ssize_t &depth) const;

[[nodiscard]] std::string ToStringImpl() const;

Expand Down
15 changes: 13 additions & 2 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@ import enum
from collections.abc import Callable, Iterable, Iterator
from typing import Any

from optree.typing import CustomTreeNode, FlattenFunc, MetaData, PyTree, T, U, UnflattenFunc
from optree.typing import (
CustomTreeNode,
FlattenFunc,
MetaData,
PyTree,
PyTreeAccessor,
PyTreeEntry,
T,
U,
UnflattenFunc,
)

class InternalError(RuntimeError): ...

Expand Down Expand Up @@ -109,7 +119,7 @@ class PyTreeSpec:
leaves: Iterable[T],
) -> U: ...
def paths(self) -> list[tuple[Any, ...]]: ...
def typed_paths(self) -> list[tuple[tuple[Any, builtins.type[Any], PyTreeKind], ...]]: ...
def accessors(self) -> list[PyTreeAccessor]: ...
def entries(self) -> list[Any]: ...
def entry(self, index: int) -> Any: ...
def children(self) -> list[PyTreeSpec]: ...
Expand Down Expand Up @@ -141,6 +151,7 @@ def register_node(
cls: type[CustomTreeNode[T]],
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
path_entry_type: type[PyTreeEntry],
namespace: str = '',
) -> None: ...
def unregister_node(
Expand Down
41 changes: 29 additions & 12 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
"""OpTree: Optimized PyTree Utilities."""

from optree import integration, typing
from optree.accessor import (
DataclassEntry,
FlattenedEntry,
GetAttrEntry,
GetItemEntry,
MappingEntry,
NamedTupleEntry,
PyTreeAccessor,
PyTreeEntry,
SequenceEntry,
StructSequenceEntry,
)
from optree.ops import (
MAX_RECURSION_DEPTH,
NONE_IS_LEAF,
Expand All @@ -23,6 +35,7 @@
broadcast_common,
broadcast_prefix,
prefix_errors,
tree_accessors,
tree_all,
tree_any,
tree_broadcast_common,
Expand All @@ -31,8 +44,8 @@
tree_broadcast_prefix,
tree_flatten,
tree_flatten_one_level,
tree_flatten_with_accessor,
tree_flatten_with_path,
tree_flatten_with_typed_path,
tree_is_leaf,
tree_iter,
tree_leaves,
Expand All @@ -50,8 +63,8 @@
tree_transpose,
tree_transpose_map,
tree_transpose_map_with_path,
tree_typed_paths,
tree_unflatten,
treespec_accessors,
treespec_child,
treespec_children,
treespec_defaultdict,
Expand All @@ -72,19 +85,15 @@
treespec_paths,
treespec_structseq,
treespec_tuple,
treespec_typed_paths,
)
from optree.registry import (
AttributeKeyPathEntry,
GetitemKeyPathEntry,
Partial,
PyTreeAccessor,
PyTreePathEntry,
register_keypaths,
register_pytree_node,
register_pytree_node_class,
unregister_pytree_node,
register_pytree_path_handler,
)
from optree.typing import (
CustomTreeNode,
Expand Down Expand Up @@ -114,13 +123,13 @@
'NONE_IS_LEAF',
'tree_flatten',
'tree_flatten_with_path',
'tree_flatten_with_typed_path',
'tree_flatten_with_accessor',
'tree_unflatten',
'tree_iter',
'tree_leaves',
'tree_structure',
'tree_paths',
'tree_typed_paths',
'tree_accessors',
'tree_is_leaf',
'all_leaves',
'tree_map',
Expand All @@ -146,7 +155,7 @@
'tree_flatten_one_level',
'prefix_errors',
'treespec_paths',
'treespec_typed_paths',
'treespec_accessors',
'treespec_entries',
'treespec_entry',
'treespec_children',
Expand All @@ -166,6 +175,17 @@
'treespec_deque',
'treespec_structseq',
'treespec_from_collection',
# Accessor
'PyTreeEntry',
'GetAttrEntry',
'GetItemEntry',
'FlattenedEntry',
'SequenceEntry',
'MappingEntry',
'NamedTupleEntry',
'StructSequenceEntry',
'DataclassEntry',
'PyTreeAccessor',
# Registry
'register_pytree_node',
'register_pytree_node_class',
Expand All @@ -174,9 +194,6 @@
'register_keypaths',
'AttributeKeyPathEntry',
'GetitemKeyPathEntry',
'PyTreeAccessor',
'PyTreePathEntry',
'register_pytree_path_handler',
# Typing
'PyTreeSpec',
'PyTreeDef',
Expand Down
Loading

0 comments on commit 1d03768

Please sign in to comment.