Skip to content

Commit

Permalink
Merge branch 'main' into typed-path
Browse files Browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Feb 25, 2024
2 parents 2b3e9f6 + 06c0b67 commit e622045
Show file tree
Hide file tree
Showing 10 changed files with 148 additions and 160 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ jobs:
make pytest
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v3
uses: codecov/codecov-action@v4
if: ${{ matrix.os == 'ubuntu-latest' }}
with:
token: ${{ secrets.CODECOV_TOKEN }}
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
hooks:
- id: clang-format
- repo: /~https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
rev: v0.2.2
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -39,11 +39,11 @@ repos:
hooks:
- id: isort
- repo: /~https://github.com/psf/black
rev: 23.12.1
rev: 24.2.0
hooks:
- id: black
- repo: /~https://github.com/asottile/pyupgrade
rev: v3.15.0
rev: v3.15.1
hooks:
- id: pyupgrade
args: [--py37-plus]
Expand Down
54 changes: 28 additions & 26 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,48 +293,52 @@ inline void SET_ITEM<py::list>(const py::handle& container,
PyList_SET_ITEM(container.ptr(), py::ssize_t_cast(item), value.inc_ref().ptr());
}

inline std::string PyRepr(const py::handle& object) {
return static_cast<std::string>(py::repr(object));
}
inline std::string PyRepr(const std::string& string) {
return static_cast<std::string>(py::repr(py::str(string)));
}

template <typename PyType>
inline void AssertExact(const py::handle& object) {
if (!py::isinstance<PyType>(object)) [[unlikely]] {
std::ostringstream oss{};
oss << "Expected an instance of " << typeid(PyType).name() << ", got "
<< static_cast<std::string>(py::repr(object)) << ".";
oss << "Expected an instance of " << typeid(PyType).name() << ", got " << PyRepr(object)
<< ".";
throw py::value_error(oss.str());
}
}
template <>
inline void AssertExact<py::list>(const py::handle& object) {
if (!PyList_CheckExact(object.ptr())) [[unlikely]] {
throw py::value_error("Expected an instance of list, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::value_error("Expected an instance of list, got " + PyRepr(object) + ".");
}
}
template <>
inline void AssertExact<py::tuple>(const py::handle& object) {
if (!PyTuple_CheckExact(object.ptr())) [[unlikely]] {
throw py::value_error("Expected an instance of tuple, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::value_error("Expected an instance of tuple, got " + PyRepr(object) + ".");
}
}
template <>
inline void AssertExact<py::dict>(const py::handle& object) {
if (!PyDict_CheckExact(object.ptr())) [[unlikely]] {
throw py::value_error("Expected an instance of dict, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::value_error("Expected an instance of dict, got " + PyRepr(object) + ".");
}
}

inline void AssertExactOrderedDict(const py::handle& object) {
if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.OrderedDict, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}

inline void AssertExactDefaultDict(const py::handle& object) {
if (!py::type::handle_of(object).is(PyDefaultDictTypeObject)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.defaultdict, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}

Expand All @@ -345,14 +349,14 @@ inline void AssertExactStandardDict(const py::handle& object) {
throw py::value_error(
"Expected an instance of dict, collections.OrderedDict, or collections.defaultdict, "
"got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}

inline void AssertExactDeque(const py::handle& object) {
if (!py::type::handle_of(object).is(PyDequeTypeObject)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.deque, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::value_error("Expected an instance of collections.deque, got " + PyRepr(object) +
".");
}
}

Expand Down Expand Up @@ -407,22 +411,22 @@ inline bool IsNamedTuple(const py::handle& object) {
inline void AssertExactNamedTuple(const py::handle& object) {
if (!IsNamedTupleInstance(object)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.namedtuple, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}
inline py::tuple NamedTupleGetFields(const py::handle& object) {
py::handle type;
if (PyType_Check(object.ptr())) [[unlikely]] {
type = object;
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw py::type_error("Expected a collections.namedtuple type, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::type_error("Expected a collections.namedtuple type, got " + PyRepr(object) +
".");
}
} else [[likely]] {
type = py::type::handle_of(object);
if (!IsNamedTupleClass(type)) [[unlikely]] {
throw py::type_error("Expected an instance of collections.namedtuple type, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}
return py::getattr(type, Py_Get_ID(_fields));
Expand Down Expand Up @@ -470,22 +474,21 @@ inline bool IsStructSequence(const py::handle& object) {
inline void AssertExactStructSequence(const py::handle& object) {
if (!IsStructSequenceInstance(object)) [[unlikely]] {
throw py::value_error("Expected an instance of PyStructSequence type, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}
inline py::tuple StructSequenceGetFields(const py::handle& object) {
py::handle type;
if (PyType_Check(object.ptr())) [[unlikely]] {
type = object;
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw py::type_error("Expected a PyStructSequence type, got " +
static_cast<std::string>(py::repr(object)) + ".");
throw py::type_error("Expected a PyStructSequence type, got " + PyRepr(object) + ".");
}
} else [[likely]] {
type = py::type::handle_of(object);
if (!IsStructSequenceClass(type)) [[unlikely]] {
throw py::type_error("Expected an instance of PyStructSequence type, got " +
static_cast<std::string>(py::repr(object)) + ".");
PyRepr(object) + ".");
}
}

Expand Down Expand Up @@ -513,11 +516,10 @@ inline void TotalOrderSort(py::list& list) { // NOLINT[runtime/references]
// Sort with `(f'{o.__class__.__module__}.{o.__class__.__qualname__}', o)`
auto sort_key_fn = py::cpp_function([](const py::object& o) {
py::handle t = py::type::handle_of(o);
py::str qualname{static_cast<std::string>(
py::getattr(t, Py_Get_ID(__module__)).cast<py::str>()) +
"." +
static_cast<std::string>(
py::getattr(t, Py_Get_ID(__qualname__)).cast<py::str>())};
py::str qualname{
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__module__)))) +
"." +
static_cast<std::string>(py::str(py::getattr(t, Py_Get_ID(__qualname__))))};
return py::make_tuple(qualname, o);
});
py::getattr(list, Py_Get_ID(sort))(py::arg("key") = sort_key_fn);
Expand Down
2 changes: 1 addition & 1 deletion optree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:

try: # noqa: SIM105 # pragma: no cover
# pylint: disable=ungrouped-imports
from jax._src.util import HashablePartial # type: ignore[assignment]
from jax._src.util import HashablePartial # type: ignore[assignment] # noqa: F811,RUF100
except ImportError: # pragma: no cover
pass

Expand Down
6 changes: 4 additions & 2 deletions optree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@

from __future__ import annotations

from typing import Any, Callable, Iterable, Sequence, overload
from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, overload

from optree.typing import S, T, U

if TYPE_CHECKING:
from optree.typing import S, T, U


def total_order_sorted(
Expand Down
54 changes: 33 additions & 21 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ test-command = """make -C "{project}" test PYTHON=python"""
# Linter tools #################################################################

[tool.black]
safe = true
line-length = 100
skip-string-normalization = true
target-version = ["py37"]
Expand All @@ -136,14 +135,14 @@ lines_after_imports = 2
multi_line_output = 3

[tool.mypy]
python_version = 3.8
python_version = "3.8"
pretty = true
show_error_codes = true
show_error_context = true
show_traceback = true
allow_redefinition = true
check_untyped_defs = true
disallow_incomplete_defs = false
disallow_incomplete_defs = true
disallow_untyped_defs = false
ignore_missing_imports = true
no_implicit_optional = true
Expand All @@ -167,8 +166,10 @@ ignore-words = "docs/source/spelling_wordlist.txt"
[tool.ruff]
target-version = "py37"
line-length = 100
show-source = true
output-format = "full"
src = ["optree", "tests"]

[tool.ruff.lint]
select = [
"E", "W", # pycodestyle
"F", # pyflakes
Expand All @@ -182,14 +183,21 @@ select = [
"COM", # flake8-commas
"C4", # flake8-comprehensions
"EXE", # flake8-executable
"FA", # flake8-future-annotations
"LOG", # flake8-logging
"ISC", # flake8-implicit-str-concat
"INP", # flake8-no-pep420
"PIE", # flake8-pie
"PYI", # flake8-pyi
"Q", # flake8-quotes
"RSE", # flake8-raise
"RET", # flake8-return
"SIM", # flake8-simplify
"TID", # flake8-tidy-imports
"TCH", # flake8-type-checking
"PERF", # perflint
"FURB", # refurb
"TRY", # tryceratops
"RUF", # ruff
]
ignore = [
Expand All @@ -210,41 +218,45 @@ ignore = [
# S101: use of `assert` detected
# internal use and may never raise at runtime
"S101",
# PLR0402: use from {module} import {name} in lieu of alias
# use alias for import convention (e.g., `import torch.nn as nn`)
"PLR0402",
# TRY003: avoid specifying long messages outside the exception class
# long messages are necessary for clarity
"TRY003",
]
typing-modules = ["optree.typing"]

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = [
"F401", # unused-import
"F401", # unused-import
]
"optree/typing.py" = [
"E402", # module-import-not-at-top-of-file
"F722", # forward-annotation-syntax-error
"F811", # redefined-while-unused
"E402", # module-import-not-at-top-of-file
"F722", # forward-annotation-syntax-error
"F811", # redefined-while-unused
]
"setup.py" = [
"ANN", # flake8-annotations
"ANN", # flake8-annotations
]
"tests/**/*.py" = [
"ANN", # flake8-annotations
"S", # flake8-bandit
"BLE", # flake8-blind-except
"SIM", # flake8-simplify
"E402", # module-import-not-at-top-of-file
"ANN", # flake8-annotations
"S", # flake8-bandit
"BLE", # flake8-blind-except
"SIM", # flake8-simplify
"INP001", # flake8-no-pep420
"E402", # module-import-not-at-top-of-file
]
"docs/source/conf.py" = [
"INP001", # flake8-no-pep420
]

[tool.ruff.flake8-annotations]
[tool.ruff.lint.flake8-annotations]
allow-star-arg-any = true

[tool.ruff.flake8-quotes]
[tool.ruff.lint.flake8-quotes]
docstring-quotes = "double"
multiline-quotes = "double"
inline-quotes = "single"

[tool.ruff.flake8-tidy-imports]
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.pytest.ini_options]
Expand Down
Loading

0 comments on commit e622045

Please sign in to comment.