Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add shortcut module optree.pytree and optree.treespec #189

Merged
merged 18 commits into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Support using system `cmake` executable during setup by [@mgorny](/~https://github.com/mgorny) in [#188](/~https://github.com/metaopt/optree/pull/188).
- Add module `optree.pytree` and `optree.treespec` by [@lqhuang](/~https://github.com/lqhuang) in [#189](/~https://github.com/metaopt/optree/pull/189).

### Changed

Expand Down
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,23 @@ True
True
```

> [!NOTE]
>
> Since `v0.14.1`, a new namespace `optree.pytree` is introduced as aliases for `optree.tree_*` functions. The following examples are equivalent to the above:
>
> ```python
> import optree.pytree as pt
> >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': 5, 'd': 6}
> >>> pt.flatten(tree)
> ([1, 2, 3, 4, 5, 6], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}))
> >>> pt.flatten(1)
> ([1], PyTreeSpec(*))
> >>> pt.flatten(None)
> ([], PyTreeSpec(None))
> >>> pt.leaves({'a': [1, 2], 'b': [3]}) == optree.tree_leaves({'b': [3], 'a': [1, 2]})
> >>> pt.structure({'a': [1, 2], 'b': [3]}) == optree.tree_structure({'b': [3], 'a': [1, 2]})
> ```

### Tree Nodes and Leaves

A tree is a collection of non-leaf nodes and leaf nodes, where the leaf nodes have no children to flatten.
Expand Down
2 changes: 2 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ OpTree: Optimized PyTree Utilities
:maxdepth: 2

ops.rst
pytree.rst
treespec.rst

.. toctree::
:maxdepth: 1
Expand Down
6 changes: 6 additions & 0 deletions docs/source/pytree.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
PyTree Operations
=================

.. currentmodule:: optree.pytree

.. automodule:: optree.pytree
6 changes: 6 additions & 0 deletions docs/source/treespec.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
TreeSpec Constructor API
========================

.. currentmodule:: optree.treespec

.. automodule:: optree.treespec
107 changes: 107 additions & 0 deletions optree/pytree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with ``PyTree``.

The :mod:`optree.pytree` namespace contains aliases of utilities from ``optree.tree_``.

>>> import optree.pytree as pt
...
...
>>> import optree.pytree as pytree
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> leaves, treespec = pytree.flatten(tree)
>>> leaves, treespec # doctest: +IGNORE_WHITESPACE
(
[1, 2, 3, 4, 5],
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> tree == pytree.unflatten(treespec, leaves)
True

.. versionadded:: 0.14.1
"""

from __future__ import annotations

from optree.ops import tree_accessors as accessors
from optree.ops import tree_all as all
from optree.ops import tree_any as any
from optree.ops import tree_broadcast_common as broadcast_common
from optree.ops import tree_broadcast_map as broadcast_map
from optree.ops import tree_broadcast_map_with_accessor as broadcast_map_with_accessor
from optree.ops import tree_broadcast_map_with_path as broadcast_map_with_path
from optree.ops import tree_broadcast_prefix as broadcast_prefix
from optree.ops import tree_flatten as flatten
from optree.ops import tree_flatten_one_level as flatten_one_level
from optree.ops import tree_flatten_with_accessor as flatten_with_accessor
from optree.ops import tree_flatten_with_path as flatten_with_path
from optree.ops import tree_is_leaf as is_leaf
from optree.ops import tree_iter as iter
from optree.ops import tree_leaves as leaves
from optree.ops import tree_map as map
from optree.ops import tree_map_ as map_
from optree.ops import tree_map_with_accessor as map_with_accessor
from optree.ops import tree_map_with_accessor_ as map_with_accessor_
from optree.ops import tree_map_with_path as map_with_path
from optree.ops import tree_map_with_path_ as map_with_path_
from optree.ops import tree_max as max
from optree.ops import tree_min as min
from optree.ops import tree_paths as paths
from optree.ops import tree_reduce as reduce
from optree.ops import tree_replace_nones as replace_nones
from optree.ops import tree_structure as structure
from optree.ops import tree_sum as sum
from optree.ops import tree_transpose as transpose
from optree.ops import tree_transpose_map as transpose_map
from optree.ops import tree_transpose_map_with_accessor as transpose_map_with_accessor
from optree.ops import tree_transpose_map_with_path as transpose_map_with_path
from optree.ops import tree_unflatten as unflatten


__all__ = [
'flatten',
'unflatten',
'flatten_one_level',
'flatten_with_path',
'flatten_with_accessor',
'leaves',
'structure',
'paths',
'accessors',
'is_leaf',
'max',
'min',
'all',
'any',
'iter',
'sum',
'reduce',
'map',
'map_',
'map_with_path',
'map_with_path_',
'map_with_accessor',
'map_with_accessor_',
'replace_nones',
'transpose',
'transpose_map',
'transpose_map_with_path',
'transpose_map_with_accessor',
'broadcast_map',
'broadcast_map_with_path',
'broadcast_map_with_accessor',
'broadcast_prefix',
'broadcast_common',
]
55 changes: 55 additions & 0 deletions optree/treespec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2022-2025 MetaOPT Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The :mod:`optree.treespec` namespace contains constructors for ``TreeSpec`` class.

>>> import optree.treespec as ts
>>> ts.leaf()
PyTreeSpec(*)
>>> ts.none()
PyTreeSpec(None)
>>> ts.dict({'a': ts.leaf(), 'b': ts.leaf()})
PyTreeSpec({'a': *, 'b': *})

.. versionadded:: 0.14.1
"""

from __future__ import annotations

from optree.ops import treespec_defaultdict as defaultdict
from optree.ops import treespec_deque as deque
from optree.ops import treespec_dict as dict
from optree.ops import treespec_from_collection as from_collection
from optree.ops import treespec_leaf as leaf
from optree.ops import treespec_list as list
from optree.ops import treespec_namedtuple as namedtuple
from optree.ops import treespec_none as none
from optree.ops import treespec_ordereddict as ordereddict
from optree.ops import treespec_structseq as structseq
from optree.ops import treespec_tuple as tuple


__all__ = [
'leaf',
'none',
'tuple',
'list',
'dict',
'namedtuple',
'ordereddict',
'defaultdict',
'deque',
'structseq',
'from_collection',
]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ main.ignore-paths = ["^_C/$", "^tests/$"]
basic.good-names = []
design.max-args = 7
format.max-line-length = 120
"messages control".disable = ["duplicate-code", "consider-using-from-import"]
"messages control".disable = ["duplicate-code", "consider-using-from-import"]
"messages control".enable = ["c-extension-no-member"]
spelling.spelling-dict = "en_US"
spelling.spelling-private-dict-file = "docs/source/spelling_wordlist.txt"
Expand Down
20 changes: 20 additions & 0 deletions tests/test_shortcut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import optree
import optree.pytree as pytree
import optree.treespec as treespec


def test_pytree_reexports():
assert set(pytree.__all__) == {
name[len('tree_') :] for name in optree.__all__ if name.startswith('tree_')
}

for name in pytree.__all__:
assert getattr(pytree, name) is getattr(optree, f'tree_{name}')


def test_treespec_reexports():
# Not all `treespec` functions are re-exported,
# only test functions exist in `optree/treespec.py` .

for name in treespec.__all__:
assert getattr(treespec, name) is getattr(optree, f'treespec_{name}')
Loading