Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(registry): add context manager to temporarily set the dictionary sorting mode #147

Merged
merged 4 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 57 additions & 7 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ jobs:
if-no-files-found: error

build-wheels:
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }}
name: Build wheels for Python ${{ matrix.python-version }} on ${{ matrix.os }} (${{ matrix.archs }})
runs-on: ${{ matrix.os }}
needs: [build-sdist]
if: github.repository == 'metaopt/optree' && (github.event_name != 'push' || startsWith(github.ref, 'refs/tags/'))
Expand All @@ -92,14 +92,65 @@ jobs:
os: [ubuntu-latest, windows-latest, macos-latest]
python-version:
["3.7", "3.8", "3.9", "3.10", "3.11", "3.12", "pypy3.9", "pypy3.10"]
archs: [
# Generic
"auto",
# Linux
"aarch64",
"ppc64le",
"s390x",
# Windows
"ARM64",
]
include:
- os: macos-13
python-version: "3.7"
archs: "auto"
exclude:
- os: ubuntu-latest
archs: "ARM64"
- os: windows-latest
archs: "aarch64"
- os: windows-latest
archs: "ppc64le"
- os: windows-latest
archs: "s390x"
- os: macos-latest
python-version: "3.7" # Python 3.7 does not support macOS ARM64
archs: "aarch64"
- os: macos-latest
archs: "ppc64le"
- os: macos-latest
archs: "s390x"
- os: macos-latest
archs: "ARM64"
- os: ubuntu-latest
python-version: "pypy3.9"
archs: "ppc64le"
- os: ubuntu-latest
python-version: "pypy3.10"
archs: "ppc64le"
- os: ubuntu-latest
python-version: "pypy3.9"
archs: "s390x"
- os: ubuntu-latest
python-version: "pypy3.10"
archs: "s390x"
- os: windows-latest
python-version: "3.7"
archs: "ARM64"
- os: windows-latest
python-version: "3.8"
archs: "ARM64"
- os: windows-latest
python-version: "pypy3.9"
archs: "ARM64"
- os: windows-latest
python-version: "pypy3.10"
archs: "ARM64"
- os: macos-latest
python-version: "3.7"
fail-fast: false
timeout-minutes: 60
timeout-minutes: 120
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down Expand Up @@ -139,17 +190,16 @@ jobs:
uses: pypa/cibuildwheel@v2.19
env:
CIBW_BUILD: ${{ env.CIBW_BUILD }}
CIBW_ARCHS_LINUX: auto aarch64 ppc64le s390x
CIBW_ARCHS_WINDOWS: auto ARM64
CIBW_ARCHS_MACOS: x86_64 arm64 universal2
CIBW_ARCHS: ${{ matrix.archs }}
CIBW_ARCHS_MACOS: ${{ matrix.archs }} universal2
with:
package-dir: .
output-dir: wheelhouse
config-file: "{package}/pyproject.toml"

- uses: actions/upload-artifact@v4
with:
name: wheels-${{ matrix.python-version }}-${{ matrix.os }}
name: wheels-${{ matrix.python-version }}-${{ matrix.os }}-${{ matrix.archs }}
path: wheelhouse/*.whl
if-no-files-found: error

Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- Add context manager to temporarily set the dictionary sorting mode by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#147](/~https://github.com/metaopt/optree/pull/147).
- Add PyPy support by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#145](/~https://github.com/metaopt/optree/pull/145).
- Add 32-bit wheels for Linux and Windows by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#141](/~https://github.com/metaopt/optree/pull/141).
- Add Linux ppc64le and s390x wheels by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#138](/~https://github.com/metaopt/optree/pull/138).
Expand All @@ -21,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Use `stable` tag instead of 2.12.0 for `pybind11` version by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#146](/~https://github.com/metaopt/optree/pull/146).
- Refactor the raw import statement in `setup.py` with `importlib` utilities by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#135](/~https://github.com/metaopt/optree/pull/135).
- Update minimal version of `typing-extensions` to 4.5.0 for `typing_extensions.deprecated` by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#134](/~https://github.com/metaopt/optree/pull/134).
- Update string representation for `OrderedDict` by [@XuehaiPan](/~https://github.com/XuehaiPan) in [#133](/~https://github.com/metaopt/optree/pull/133).
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ pytest: pytest-install
$(PYTHON) -m pytest --version
cd tests && $(PYTHON) -X dev -c 'import $(PROJECT_PATH)' && \
$(PYTHON) -X dev -c 'import $(PROJECT_PATH)._C; print(f"GLIBCXX_USE_CXX11_ABI={$(PROJECT_PATH)._C.GLIBCXX_USE_CXX11_ABI}")' && \
$(PYTHON) -X dev -m pytest --verbose --color=yes \
$(PYTHON) -X dev -m pytest --verbose --color=yes --durations=0 --showlocals \
--cov="$(PROJECT_PATH)" --cov-config=.coveragerc --cov-report=xml --cov-report=term-missing \
$(PYTESTOPTS) .

Expand Down Expand Up @@ -152,7 +152,7 @@ mypy: mypy-install

xdoctest: xdoctest-install
$(PYTHON) -m xdoctest --version
$(PYTHON) -m xdoctest $(PROJECT_PATH)
$(PYTHON) -m xdoctest --global-exec "from optree import *" $(PROJECT_PATH)

doctest: xdoctest

Expand Down
2 changes: 2 additions & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Tree Manipulation Functions

.. autosummary::

dict_insertion_ordered
tree_flatten
tree_flatten_with_path
tree_flatten_with_accessor
Expand Down Expand Up @@ -55,6 +56,7 @@ Tree Manipulation Functions
tree_flatten_one_level
prefix_errors

.. autofunction:: dict_insertion_ordered
.. autofunction:: tree_flatten
.. autofunction:: tree_flatten_with_path
.. autofunction:: tree_flatten_with_accessor
Expand Down
39 changes: 34 additions & 5 deletions include/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,25 @@ class PyTreeSpec {
const bool &none_is_leaf = false,
const std::string &registry_namespace = "");

// Check if should preserve the insertion order of the dictionary keys during flattening.
inline static bool IsDictInsertionOrdered(const std::string &registry_namespace,
const bool &inherit_global_namespace = true) {
return (sm_is_dict_insertion_ordered.find(registry_namespace) !=
sm_is_dict_insertion_ordered.end()) ||
(inherit_global_namespace &&
sm_is_dict_insertion_ordered.find("") != sm_is_dict_insertion_ordered.end());
}

// Set the namespace to preserve the insertion order of the dictionary keys during flattening.
inline static void SetDictInsertionOrdered(const bool &mode,
const std::string &registry_namespace) {
if (mode) [[likely]] {
sm_is_dict_insertion_ordered.insert(registry_namespace);
} else [[unlikely]] {
sm_is_dict_insertion_ordered.erase(registry_namespace);
}
}

private:
using RegistrationPtr = PyTreeTypeRegistry::RegistrationPtr;

Expand Down Expand Up @@ -266,7 +285,7 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename Span>
template <bool NoneIsLeaf, bool DictShouldBeSorted, typename Span>
bool FlattenIntoImpl(const py::handle &handle,
Span &leaves, // NOLINT[runtime/references]
const ssize_t &depth,
Expand All @@ -281,7 +300,11 @@ class PyTreeSpec {
const bool &none_is_leaf,
const std::string &registry_namespace);

template <bool NoneIsLeaf, typename LeafSpan, typename PathSpan, typename Stack>
template <bool NoneIsLeaf,
bool DictShouldBeSorted,
typename LeafSpan,
typename PathSpan,
typename Stack>
bool FlattenIntoWithPathImpl(const py::handle &handle,
LeafSpan &leaves, // NOLINT[runtime/references]
PathSpan &paths, // NOLINT[runtime/references]
Expand Down Expand Up @@ -329,6 +352,10 @@ class PyTreeSpec {
size_t operator()(const std::pair<const PyTreeSpec *, std::thread::id> &p) const;
};

// A set of namespaces that preserve the insertion order of the dictionary keys during
// flattening.
inline static std::unordered_set<std::string> sm_is_dict_insertion_ordered{};

// A set of (treespec, thread_id) pairs that are currently being represented as strings.
inline static std::unordered_set<std::pair<const PyTreeSpec *, std::thread::id>,
ThreadIndentTypeHash>
Expand All @@ -344,12 +371,13 @@ class PyTreeIter {
public:
PyTreeIter(const py::object &tree,
const std::optional<py::function> &leaf_predicate,
bool none_is_leaf,
std::string registry_namespace)
const bool &none_is_leaf,
const std::string &registry_namespace)
: m_agenda({std::make_pair(tree, 0)}),
m_leaf_predicate(leaf_predicate),
m_none_is_leaf(none_is_leaf),
m_namespace(std::move(registry_namespace)) {};
m_namespace(registry_namespace),
m_is_dict_insertion_ordered(PyTreeSpec::IsDictInsertionOrdered(registry_namespace)) {};

PyTreeIter() = delete;

Expand Down Expand Up @@ -377,6 +405,7 @@ class PyTreeIter {
std::optional<py::function> m_leaf_predicate;
bool m_none_is_leaf;
std::string m_namespace;
bool m_is_dict_insertion_ordered;

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
Expand Down
8 changes: 8 additions & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,11 @@ def unregister_node(
cls: type[CustomTreeNode[T]],
namespace: str = '',
) -> None: ...
def is_dict_insertion_ordered(
namespace: str = '',
inherit_global_namespace: bool = True,
) -> bool: ...
def set_dict_insertion_ordered(
mode: bool,
namespace: str = '',
) -> None: ...
2 changes: 2 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
from optree.registry import (
AttributeKeyPathEntry,
GetitemKeyPathEntry,
dict_insertion_ordered,
register_keypaths,
register_pytree_node,
register_pytree_node_class,
Expand Down Expand Up @@ -200,6 +201,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
# Typing
'PyTreeSpec',
'PyTreeDef',
Expand Down
89 changes: 89 additions & 0 deletions optree/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import contextlib
import dataclasses
import functools
import inspect
Expand All @@ -29,6 +30,7 @@
Any,
Callable,
ClassVar,
Generator,
Iterable,
NamedTuple,
Sequence,
Expand Down Expand Up @@ -70,6 +72,7 @@
'register_pytree_node',
'register_pytree_node_class',
'unregister_pytree_node',
'dict_insertion_ordered',
]


Expand Down Expand Up @@ -491,6 +494,17 @@ def _dict_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
return dict(safe_zip(keys, values))


def _dict_insertion_ordered_flatten(
dct: dict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
keys, values = unzip2(dct.items())
return values, list(keys), keys


def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]:
return dict(safe_zip(keys, values))


def _ordereddict_flatten(
dct: OrderedDict[KT, VT],
) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]:
Expand All @@ -517,6 +531,21 @@ def _defaultdict_unflatten(
return defaultdict(default_factory, _dict_unflatten(keys, values))


def _defaultdict_insertion_ordered_flatten(
dct: defaultdict[KT, VT],
) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]:
values, keys, entries = _dict_insertion_ordered_flatten(dct)
return values, (dct.default_factory, keys), entries


def _defaultdict_insertion_ordered_unflatten(
metadata: tuple[Callable[[], VT], list[KT]],
values: Iterable[VT],
) -> defaultdict[KT, VT]:
default_factory, keys = metadata
return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values))


def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]:
return deq, deq.maxlen

Expand Down Expand Up @@ -566,6 +595,23 @@ def _pytree_node_registry_get(
handler = _NODETYPE_REGISTRY.get((namespace, cls))
if handler is not None:
return handler

if _C.is_dict_insertion_ordered(namespace):
if cls is dict:
return PyTreeNodeRegistryEntry(
dict,
_dict_insertion_ordered_flatten, # type: ignore[arg-type]
_dict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)
if cls is defaultdict:
return PyTreeNodeRegistryEntry(
defaultdict,
_defaultdict_insertion_ordered_flatten, # type: ignore[arg-type]
_defaultdict_insertion_ordered_unflatten, # type: ignore[arg-type]
path_entry_type=MappingEntry,
)

handler = _NODETYPE_REGISTRY.get(cls)
if handler is not None:
return handler
Expand All @@ -580,6 +626,49 @@ def _pytree_node_registry_get(
del _pytree_node_registry_get


@contextlib.contextmanager
def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None, None, None]:
"""Context manager to temporarily set the dictionary sorting mode.

This context manager is used to temporarily set the dictionary sorting mode for a specific
namespace. The dictionary sorting mode is used to determine whether the keys of a dictionary
should be sorted or keeping the insertion order when flattening a pytree.

>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
>>> tree_flatten(tree) # doctest: +IGNORE_WHITESPACE
(
[1, 2, 3, 4, 5],
PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': None, 'd': *})
)
>>> with dict_insertion_ordered(True, namespace='some-namespace'): # doctest: +IGNORE_WHITESPACE
... tree_flatten(tree, namespace='some-namespace')
(
[2, 3, 4, 1, 5],
PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': None, 'd': *}, namespace='some-namespace')
)

Args:
mode (bool): The dictionary sorting mode to set.
namespace (str): The namespace to set the dictionary sorting mode for.
"""
if namespace is not __GLOBAL_NAMESPACE and not isinstance(namespace, str):
raise TypeError(f'The namespace must be a string, got {namespace!r}.')
if namespace == '':
raise ValueError('The namespace cannot be an empty string.')
if namespace is __GLOBAL_NAMESPACE:
namespace = ''

with __REGISTRY_LOCK:
prev = _C.is_dict_insertion_ordered(namespace, inherit_global_namespace=False)
_C.set_dict_insertion_ordered(bool(mode), namespace)

try:
yield
finally:
with __REGISTRY_LOCK:
_C.set_dict_insertion_ordered(prev, namespace)


####################################################################################################

warnings.filterwarnings('ignore', category=FutureWarning, module=__name__, append=True)
Expand Down
Loading
Loading