From a11336ffe1786b76b4b4f322a943054fe3172120 Mon Sep 17 00:00:00 2001 From: Jeremy Maitin-Shepard Date: Fri, 5 Jul 2024 22:37:13 -0700 Subject: [PATCH] Add support for Python type parameter lists Sphinx has partial support for type parameter lists: they are supported by the Python domain in signatures, but are not supported by autodoc. This adds the following support: - Sphinx Python domain for type parameter fields in docstrings, with sphinx.ext.napoleon support as well. - Support for type parameters as Sphinx objects, with cross-linking, like the existing support for function parameters as Sphinx objects. - Support in apigen for PEP 695 type parameters, and for displaying pre-PEP 695 separately-defined TypeVar types as PEP 695 type parameters. --- .../apidoc/object_description_options.py | 17 +- sphinx_immaterial/apidoc/object_toc.py | 1 - sphinx_immaterial/apidoc/python/apigen.py | 179 +++++++++++- .../apidoc/python/parameter_objects.py | 263 ++++++++++++++---- .../python/type_annotation_transforms.py | 25 +- .../apidoc/python/type_param_utils.py | 244 ++++++++++++++++ tests/python_type_param_utils_test.py | 42 +++ 7 files changed, 701 insertions(+), 70 deletions(-) create mode 100644 sphinx_immaterial/apidoc/python/type_param_utils.py create mode 100644 tests/python_type_param_utils_test.py diff --git a/sphinx_immaterial/apidoc/object_description_options.py b/sphinx_immaterial/apidoc/object_description_options.py index 240232e95..9392aa782 100644 --- a/sphinx_immaterial/apidoc/object_description_options.py +++ b/sphinx_immaterial/apidoc/object_description_options.py @@ -50,7 +50,22 @@ def format_object_description_tooltip( ("py:property", {"toc_icon_class": "alias", "toc_icon_text": "P"}), ("py:attribute", {"toc_icon_class": "alias", "toc_icon_text": "A"}), ("py:data", {"toc_icon_class": "alias", "toc_icon_text": "V"}), - ("py:parameter", {"toc_icon_class": "sub-data", "toc_icon_text": "p"}), + ( + "py:parameter", + { + "toc_icon_class": "sub-data", + "toc_icon_text": "p", + "generate_synopses": "first_sentence", + }, + ), + ( + "py:typeParameter", + { + "toc_icon_class": "alias", + "toc_icon_text": "T", + "generate_synopses": "first_sentence", + }, + ), ("c:member", {"toc_icon_class": "alias", "toc_icon_text": "V"}), ("c:var", {"toc_icon_class": "alias", "toc_icon_text": "V"}), ("c:function", {"toc_icon_class": "procedure", "toc_icon_text": "F"}), diff --git a/sphinx_immaterial/apidoc/object_toc.py b/sphinx_immaterial/apidoc/object_toc.py index 90c29bdea..9747e5037 100644 --- a/sphinx_immaterial/apidoc/object_toc.py +++ b/sphinx_immaterial/apidoc/object_toc.py @@ -73,7 +73,6 @@ def _make_section_from_field( source: docutils.nodes.field, ) -> Optional[docutils.nodes.section]: fieldname = cast(docutils.nodes.field_name, source[0]) - fieldbody = cast(docutils.nodes.field_body, source[1]) ids = fieldname["ids"] if not ids: # Not indexed diff --git a/sphinx_immaterial/apidoc/python/apigen.py b/sphinx_immaterial/apidoc/python/apigen.py index 0d1c01af1..b642cab31 100644 --- a/sphinx_immaterial/apidoc/python/apigen.py +++ b/sphinx_immaterial/apidoc/python/apigen.py @@ -20,6 +20,7 @@ import inspect import json import re +import typing from typing import ( List, Tuple, @@ -54,6 +55,7 @@ from .. import object_description_options from ... import sphinx_utils from .. import apigen_utils +from . import type_param_utils if sphinx.version_info >= (6, 1): stringify_annotation = sphinx.util.typing.stringify_annotation @@ -266,6 +268,10 @@ def _must_shorten(): if not added_ellipsis: added_ellipsis = True ellipsis_node = sphinx.addnodes.desc_sig_punctuation("", "...") + # When using the `sphinx_immaterial.apidoc.format_signatures` + # extension, replace the text of this node to make it valid Python + # syntax. + ellipsis_node["munged_text_for_formatting"] = "___" param = sphinx.addnodes.desc_parameter() param += ellipsis_node parameterlist += param @@ -295,6 +301,8 @@ class _MemberDocumenterEntry(NamedTuple): subscript: bool = False """Whether this is a "subscript" method to be shown with [] instead of ().""" + type_param_substitutions: Optional[type_param_utils.TypeParamSubstitutions] = None + @property def overload_suffix(self): if self.overload and self.overload.overload_id: @@ -312,6 +320,7 @@ class _ApiEntityMemberReference(NamedTuple): parent_canonical_object_name: str inherited: bool siblings: List["_ApiEntityMemberReference"] + type_param_substitutions: Optional[type_param_utils.TypeParamSubstitutions] @dataclasses.dataclass @@ -366,6 +375,11 @@ def overload_suffix(self) -> str: primary_entity: bool = True """Indicates if this is the primary sibling and should be documented.""" + type_params: Optional[Tuple[type_param_utils.TypeParam, ...]] = None + parent_type_params: Optional[ + Tuple[str, Tuple[type_param_utils.TypeParam, ...]] + ] = None + def _is_constructor_name(name: str) -> bool: return name in ("__init__", "__new__", "__class_getitem__") @@ -491,7 +505,7 @@ def _clean_init_signature(signode: sphinx.addnodes.desc_signature) -> None: Removes the return type (always None) and the self parameter (since these methods are displayed as the class name, without showing __init__). - :param node: Signature to modify in place. + :param signode: Signature to modify in place. """ # Remove first parameter. for param in signode.findall(condition=sphinx.addnodes.desc_parameter): @@ -510,7 +524,7 @@ def _clean_class_getitem_signature(signode: sphinx.addnodes.desc_signature) -> N Removes the `static` prefix since these methods are shown using the class name (i.e. as "subscript" constructors). - :param node: Signature to modify in place. + :param signode: Signature to modify in place. """ # Remove `static` prefix @@ -519,6 +533,60 @@ def _clean_class_getitem_signature(signode: sphinx.addnodes.desc_signature) -> N break +def _insert_parent_type_params( + env: sphinx.environment.BuildEnvironment, + signode: sphinx.addnodes.desc_signature, + parent_symbol: str, + parent_type_params: tuple[type_param_utils.TypeParam, ...], + is_constructor: bool, +) -> None: + def _make_type_list_node(): + tp_list = type_param_utils.stringify_type_params(parent_type_params) + tp_list_node = sphinx.domains.python._annotations._parse_type_list(tp_list, env) + for desc_param_node in tp_list_node.findall( + condition=sphinx.addnodes.desc_type_parameter + ): + desc_param_node[ + "sphinx_immaterial_type_param_symbol_prefix" + ] = parent_symbol + return tp_list_node + + if is_constructor: + for node in signode.findall(sphinx.addnodes.desc_name): + break + else: + return False + node.parent.insert(node.parent.index(node) + 1, _make_type_list_node()) + return True + + prev_name_node = None + for node in signode.findall(): + if isinstance(node, sphinx.addnodes.desc_addname): + prev_name_node = node + elif isinstance(node, sphinx.addnodes.desc_name): + break + else: + return False + + if prev_name_node is None: + return False + + index = prev_name_node.parent.index(prev_name_node) + prev_name_text = prev_name_node.astext().rstrip(".") + prev_name_node.replace_self( + sphinx.addnodes.desc_addname(prev_name_text, prev_name_text) + ) + tp_list_node = _make_type_list_node() + prev_name_node.parent.insert( + index + 1, + [ + _make_type_list_node(), + sphinx.addnodes.desc_addname(".", "."), + ], + ) + return True + + def _get_api_data( env: sphinx.environment.BuildEnvironment, ) -> _ApiData: @@ -573,6 +641,15 @@ def object_description_transform( obj_desc["classes"].append("summary") assert app.env is not None _summarize_signature(app.env, signode) + elif entity.parent_type_params: + # Insert additional type parameters + _insert_parent_type_params( + app.env, + signode, + entity.parent_type_params[0], + entity.parent_type_params[1], + is_constructor=_is_constructor_name(entity.documented_name), + ) base_classes = entity.base_classes if base_classes: @@ -587,6 +664,9 @@ def object_description_transform( signode += _parse_annotation(base_class, env) signode += sphinx.addnodes.desc_sig_punctuation("", ")") + if not summary: + _ensure_module_name_in_signature(signode) + if callback is not None: callback(contentnode) @@ -638,7 +718,24 @@ def object_description_transform( signatures: List[str] = [] for e, m in zip(all_entities, all_members): name = api_data.get_name_for_signature(e, m) - signatures.extend(name + sig for sig in e.signatures) + unnamed_signatures = e.signatures + if m is not None and m.type_param_substitutions: + unnamed_signatures = [ + type_param_utils.substitute_type_params( + sig, m.type_param_substitutions + ) + for sig in unnamed_signatures + ] + signatures.extend(name + sig for sig in unnamed_signatures) + + if ( + (m := all_members[0]) is not None + and "type" in options + and m.type_param_substitutions + ): + options["type"] = type_param_utils.substitute_type_params( + options["type"], all_members[0].type_param_substitutions + ) sphinx_utils.append_directive_to_stringlist( rst_input, @@ -850,10 +947,6 @@ def run(self) -> List[docutils.nodes.Node]: ), ) - for signode in objdesc.children[:-1]: - signode = cast(sphinx.addnodes.desc_signature, signode) - _ensure_module_name_in_signature(signode) - # Wrap in a section section = docutils.nodes.section() section["ids"].append("") @@ -1323,6 +1416,7 @@ def member_sort_key(entry): def _get_documenter_members( + app: sphinx.application.Sphinx, documenter: sphinx.ext.autodoc.Documenter, canonical_full_name: str, ) -> Iterator[_MemberDocumenterEntry]: @@ -1334,30 +1428,46 @@ def _get_documenter_members( seen_members: Set[str] = set() def _get_unseen_members( - members: Iterator[_MemberDocumenterEntry], is_inherited: bool + members: Iterator[_MemberDocumenterEntry], + is_inherited: bool, + type_param_substitutions: Optional[type_param_utils.TypeParamSubstitutions], ) -> Iterator[_MemberDocumenterEntry]: for member in members: overload_name = member.toc_title if overload_name in seen_members: continue seen_members.add(overload_name) - yield member._replace(is_inherited=is_inherited) + yield member._replace( + is_inherited=is_inherited, + type_param_substitutions=type_param_substitutions, + ) yield from _get_unseen_members( _get_documenter_direct_members( documenter, canonical_full_name=canonical_full_name ), is_inherited=False, + type_param_substitutions=None, ) if documenter.objtype != "class": return + base_class_type_param_substitutions = ( + type_param_utils.get_base_class_type_param_substitutions(documenter.object) + ) + for cls in inspect.getmro(documenter.object): if cls is documenter.object: continue - if cls.__module__ in ("builtins", "pybind11_builtins"): + skip_user = app.emit_firstresult("python-apigen-skip-base", object, cls) + if skip_user is True: continue + if skip_user is None: + if cls.__module__ in ("builtins", "pybind11_builtins"): + continue + if cls is typing.Generic: + continue class_name = f"{cls.__module__}::{cls.__qualname__}" parent_canonical_full_name = f"{cls.__module__}.{cls.__qualname__}" try: @@ -1373,6 +1483,7 @@ def _get_unseen_members( canonical_full_name=parent_canonical_full_name, ), is_inherited=True, + type_param_substitutions=base_class_type_param_substitutions.get(cls), ) except Exception as e: # pylint: disable=broad-except logger.warning( @@ -1474,8 +1585,10 @@ def _summarize_rst_content(content: List[str]) -> List[str]: class _ApiEntityCollector: def __init__( self, + app: sphinx.application.Sphinx, entities: Dict[str, _ApiEntity], ): + self.app = app self.entities = entities def collect_entity_recursively( @@ -1535,12 +1648,18 @@ def document_members(*args, **kwargs): base_classes: Optional[List[str]] = None + type_params = () + if isinstance(entry.documenter, sphinx.ext.autodoc.ClassDocumenter): + type_params = type_param_utils.get_class_type_params( + entry.documenter.object + ) + # By default (unless the `autodoc_class_signature` config option is # set to `"separated"`), autodoc will include the `__init__` # parameters in the signature. Since that convention does not work # well with this extension, we just bypass that here. - signatures = [""] + signatures = [type_param_utils.stringify_type_params(type_params)] if entry.documenter.config.python_apigen_show_base_classes: obj = entry.documenter.object @@ -1559,7 +1678,10 @@ def document_members(*args, **kwargs): base_classes = [ stringify_annotation(base) for base in base_list - if base is not object + if ( + base is not object + and typing.get_origin(base) is not typing.Generic + ) ] else: signatures = entry.documenter.format_signature().split("\n") @@ -1584,6 +1706,7 @@ def document_members(*args, **kwargs): overload_id=overload_id or "", base_classes=base_classes, primary_entity=primary_entity is None, + type_params=type_params, ) self.entities[canonical_object_name] = entity @@ -1611,7 +1734,7 @@ def collect_documenter_members( ] = {} for entry in _get_documenter_members( - documenter, canonical_full_name=canonical_object_name + self.app, documenter, canonical_full_name=canonical_object_name ): obj = None if isinstance( @@ -1638,7 +1761,8 @@ def collect_documenter_members( primary_sibling_member.canonical_object_name ] member_canonical_object_name = self.collect_entity_recursively( - entry, primary_entity=primary_sibling_entity + entry, + primary_entity=primary_sibling_entity, ) child = self.entities[member_canonical_object_name] member = _ApiEntityMemberReference( @@ -1647,6 +1771,7 @@ def collect_documenter_members( canonical_object_name=member_canonical_object_name, inherited=entry.is_inherited, siblings=[], + type_param_substitutions=entry.type_param_substitutions, ) if primary_sibling_member is not None: @@ -1732,6 +1857,29 @@ def parent_sort_key(parent_ref: _ApiEntityMemberReference): else: parent_documented_name = get_documented_full_name(parent_entity) entity.options["module"] = parent_entity.options["module"] + if parent_entity.type_params: + if entity.objtype != "method" or ( + not entity.options.get("classmethod") + and not entity.options.get("staticmethod") + ): + entity.parent_type_params = ( + parent_entity.object_name, + parent_entity.type_params, + ) + + # Resolve type parameters + if entity.objtype != "class": + for i, signature in enumerate(entity.signatures): + type_params = type_param_utils.get_type_params_from_signature(signature) + if entity.parent_type_params: + for param in entity.parent_type_params[1]: + type_params.pop(param.__name__, None) + if type_params: + entity.signatures[i] = ( + type_param_utils.stringify_type_params(type_params.values()) + + signature + ) + documented_full_name = parent_documented_name + "." + parent_ref.name entity.documented_full_name = documented_full_name entity.documented_name = parent_ref.name @@ -1788,6 +1936,7 @@ def _builder_inited(app: sphinx.application.Sphinx) -> None: name=module_name, ) _ApiEntityCollector( + app=app, entities=data.entities, ).collect_documenter_members( documenter=documenter, @@ -1951,4 +2100,6 @@ def setup(app: sphinx.application.Sphinx): app.add_config_value( "python_apigen_rst_epilog", types=(str,), default="", rebuild="env" ) + app.add_event("python-apigen-skip-base") + return {"parallel_read_safe": True, "parallel_write_safe": True} diff --git a/sphinx_immaterial/apidoc/python/parameter_objects.py b/sphinx_immaterial/apidoc/python/parameter_objects.py index 4fde71bbf..0317a17d0 100644 --- a/sphinx_immaterial/apidoc/python/parameter_objects.py +++ b/sphinx_immaterial/apidoc/python/parameter_objects.py @@ -1,16 +1,34 @@ -from typing import Optional, cast, List, Dict, Sequence, Tuple, Any, Iterator +from typing import ( + Optional, + cast, + List, + Dict, + Sequence, + Tuple, + Any, + Iterator, + TypeVar, + Literal, +) import docutils.nodes from sphinx.domains.python import PyTypedField from sphinx.domains.python import PythonDomain from sphinx.domains.python import PyObject +from sphinx.locale import _ import sphinx.util.logging +import sphinx.ext.autodoc +import sphinx.util.typing +import sphinx.ext.autodoc.typehints +import sphinx.util.inspect +import sphinx.ext.napoleon.docstring +from sphinx.ext.napoleon.docstring import GoogleDocstring + from . import annotation_style from .. import object_description_options from ... import sphinx_utils - logger = sphinx.util.logging.getLogger(__name__) @@ -205,7 +223,7 @@ def get_objects( self: PythonDomain, ) -> Iterator[Tuple[str, str, str, str, str, int]]: for obj in orig_get_objects(self): - if obj[2] != "parameter": + if obj[2] in ("parameter", "typeParameter"): yield obj else: yield ( @@ -223,7 +241,8 @@ def get_objects( def _add_parameter_links_to_signature( env: sphinx.environment.BuildEnvironment, signode: sphinx.addnodes.desc_signature, - symbol: str, + type_param_symbol_prefix: str, + function_param_symbol_prefix: str, ) -> Dict[str, docutils.nodes.Element]: """Cross-links parameter names in signature to parameter objects. @@ -232,16 +251,20 @@ def _add_parameter_links_to_signature( """ sig_param_nodes: Dict[str, docutils.nodes.Element] = {} - replacements = [] + type_param_symbols: dict[str, str] = {} + + replacements: list[tuple[docutils.nodes.Element, str]] = [] node_identifier_key = "sphinx_immaterial_param_name_identifier" def add_replacement( - name_node: docutils.nodes.Element, param_node: docutils.nodes.Element + name_node: docutils.nodes.Element, + param_node: docutils.nodes.Element, + param_symbol: str, ) -> docutils.nodes.Element: - replacements.append((name_node, param_node)) name = name_node.astext() - # Mark `name_node` so that it can be identified after the deep copy of its - # ancestor `param_node`. + replacements.append((name_node, param_symbol)) + # Temporarily mark `name_node` so that it can be identified after the + # deep copy of its ancestor `param_node`. name_node[node_identifier_key] = True param_node_copy = param_node.deepcopy() source, line = docutils.utils.get_source_line(param_node) @@ -249,51 +272,115 @@ def add_replacement( param_node_copy.line = line sig_param_nodes[name] = param_node_copy del name_node[node_identifier_key] + + # Locate the copy of `name_node` within `param_node_copy`. for name_node_copy in param_node_copy.findall(condition=type(name_node)): if name_node_copy.get(node_identifier_key): - return name_node_copy - raise ValueError("Could not locate name node within parameter") - - for desc_param_node in signode.findall(condition=sphinx.addnodes.desc_parameter): - for sig_param_node in desc_param_node: - if not isinstance(sig_param_node, sphinx.addnodes.desc_sig_name): - continue - new_sig_param_node = add_replacement(sig_param_node, desc_param_node) - new_sig_param_node["classes"].append("sig-name") - break + name_node_copy["classes"].append("sig-name") + break + else: + raise ValueError("Could not locate name node within parameter") + + def _collect_parameters( + nodetype: type[docutils.nodes.Element], symbol_prefix: str, is_type_param: bool + ): + for desc_param_node in signode.findall(condition=nodetype): + cur_symbol_prefix = desc_param_node.get( + "sphinx_immaterial_type_param_symbol_prefix", symbol_prefix + ) + for sig_param_node in desc_param_node: + if not isinstance(sig_param_node, sphinx.addnodes.desc_sig_name): + continue + name = sig_param_node.astext() + param_symbol = f"{cur_symbol_prefix}.{name}" + if is_type_param: + type_param_symbols[name] = param_symbol + add_replacement(sig_param_node, desc_param_node, param_symbol) + break + + _collect_parameters( + sphinx.addnodes.desc_parameter, + function_param_symbol_prefix, + is_type_param=False, + ) + _collect_parameters( + sphinx.addnodes.desc_type_parameter, + type_param_symbol_prefix, + is_type_param=True, + ) - for name_node, param_node in replacements: - name = name_node.astext() + for name_node, param_symbol in replacements: refnode = sphinx.addnodes.pending_xref( "", name_node.deepcopy(), refdomain="py", reftype="param", - reftarget=f"{symbol}.{name}", + reftarget=param_symbol, refwarn=False, ) refnode["implicit_sig_param_ref"] = True name_node.replace_self(refnode) + # Also cross-link references to type parameters in annotations. + for xref in signode.findall(condition=sphinx.addnodes.pending_xref): + if xref["refdomain"] == "py" and xref["reftype"] in ("class", "param"): + param_symbol = type_param_symbols.get(xref["reftarget"]) + if param_symbol is not None: + xref["reftarget"] = param_symbol + xref["refspecific"] = False + return sig_param_nodes +def _collate_parameter_symbols( + sig_param_nodes_for_signature: List[Dict[str, docutils.nodes.Element]], + symbols: list[str], + function_symbols: list[str], +) -> dict[str, tuple[Literal["parameter", "typeParameter"], list[str]]]: + param_symbols: dict[str, tuple[str, list[str]]] = {} + + for i, sig_param_nodes in enumerate(sig_param_nodes_for_signature): + for name, desc_param_node in sig_param_nodes.items(): + if isinstance(desc_param_node, sphinx.addnodes.desc_type_parameter): + if desc_param_node.get("sphinx_immaterial_type_param_symbol_prefix"): + continue + param_objtype = "typeParameter" + symbol = symbols[i] + else: + param_objtype = "parameter" + symbol = function_symbols[i] + existing = param_symbols.get(name) + param_symbol = f"{symbol}.{name}" + if existing is not None: + if existing[0] != param_objtype: + logger.warning( + "Parameter %r is both a type parameter and a function parameter", + name, + location=desc_param_node, + ) + continue + if param_symbol not in existing[1]: + existing[1].append(param_symbol) + else: + param_symbols[name] = (param_objtype, [param_symbol]) + return param_symbols + + def _add_parameter_documentation_ids( directive: sphinx.domains.python.PyObject, env: sphinx.environment.BuildEnvironment, obj_content: sphinx.addnodes.desc_content, sig_param_nodes_for_signature: List[Dict[str, docutils.nodes.Element]], - symbols: List[str], + symbols: list[str], + function_symbols: list[str], noindex: bool, -) -> None: +) -> set[str]: qualify_parameter_ids = "nonodeid" not in directive.options - param_options = object_description_options.get_object_description_options( - env, "py", "parameter" - ) - py = cast(sphinx.domains.python.PythonDomain, env.get_domain("py")) + noted_param_symbols: set[str] = set() + def cross_link_single_parameter( param_name: str, param_node: docutils.nodes.term ) -> None: @@ -306,15 +393,21 @@ def cross_link_single_parameter( # Identical declarations in more than one signature will only be # included once. unique_decls: Dict[str, Tuple[int, docutils.nodes.Element]] = {} - unique_symbols: Dict[Tuple[str, str], int] = {} + unique_symbols: Dict[str, bool] = {} + param_objtype = "parameter" for i, sig_param_nodes in enumerate(sig_param_nodes_for_signature): desc_param_node = sig_param_nodes.get(param_name) if desc_param_node is None: continue desc_param_node = cast(docutils.nodes.Element, desc_param_node) + if isinstance(desc_param_node, sphinx.addnodes.desc_type_parameter): + param_objtype = "typeParameter" + symbol = ( + symbols[i] if param_objtype == "typeParameter" else function_symbols[i] + ) decl_text = desc_param_node.astext().strip() unique_decls.setdefault(decl_text, (i, desc_param_node)) - unique_symbols.setdefault((decl_text, symbols[i]), i) + unique_symbols.setdefault(symbol, True) if not unique_decls: all_params = {} for sig_param_nodes in sig_param_nodes_for_signature: @@ -329,6 +422,10 @@ def cross_link_single_parameter( return if not noindex: + param_options = object_description_options.get_object_description_options( + env, "py", param_objtype + ) + synopsis: Optional[str] generate_synopses = param_options["generate_synopses"] if generate_synopses is not None: @@ -341,15 +438,9 @@ def cross_link_single_parameter( unqualified_param_id = f"p-{param_name}" - param_symbols = set() - # Set ids of the parameter node. - for symbol_i in unique_symbols.values(): - symbol = symbols[symbol_i] + for symbol in unique_symbols: param_symbol = f"{symbol}.{param_name}" - if param_symbol in param_symbols: - continue - param_symbols.add(param_symbol) if synopsis: py.data["synopses"][param_symbol] = synopsis @@ -362,7 +453,10 @@ def cross_link_single_parameter( else: node_id = unqualified_param_id - py.note_object(param_symbol, "parameter", node_id, location=param_node) + py.note_object( + param_symbol, param_objtype, node_id, location=param_node + ) + noted_param_symbols.add(param_symbol) if param_options["include_in_toc"]: toc_title = param_name @@ -422,15 +516,18 @@ def cross_link_single_parameter( if not param_name: continue cross_link_single_parameter(param_name, term) + return noted_param_symbols def _cross_link_parameters( directive: sphinx.domains.python.PyObject, app: sphinx.application.Sphinx, - signodes: List[sphinx.addnodes.desc_signature], + signodes: list[sphinx.addnodes.desc_signature], content: sphinx.addnodes.desc_content, - symbols: List[str], + symbols: list[str], + function_symbols: list[str], noindex: bool, + node_id: str, ) -> None: env = app.env assert isinstance(env, sphinx.environment.BuildEnvironment) @@ -443,23 +540,39 @@ def _cross_link_parameters( # replace the bare parameter name so that the parameter description shows # e.g. `x : int = 10` rather than just `x`. sig_param_nodes_for_signature = [] - for signode, symbol in zip(signodes, symbols): + for signode, symbol, function_symbol in zip(signodes, symbols, function_symbols): sig_param_nodes_for_signature.append( - _add_parameter_links_to_signature(env, signode, symbol) + _add_parameter_links_to_signature(env, signode, symbol, function_symbol) ) # Find all parameter descriptions in the object description body, and mark # them as the target for cross links to that parameter. Also substitute in # the parameter declaration for the bare parameter name, as described above. - _add_parameter_documentation_ids( + noted_param_symbols = _add_parameter_documentation_ids( directive=directive, env=env, obj_content=content, sig_param_nodes_for_signature=sig_param_nodes_for_signature, symbols=symbols, + function_symbols=function_symbols, noindex=noindex, ) + if not noindex: + py = cast(sphinx.domains.python.PythonDomain, env.get_domain("py")) + + param_symbols_by_name = _collate_parameter_symbols( + sig_param_nodes_for_signature, symbols, function_symbols + ) + + for name, (param_objtype, param_symbols) in param_symbols_by_name.items(): + for param_symbol in param_symbols: + if param_symbol in noted_param_symbols: + continue + py.note_object( + param_symbol, param_objtype, node_id, location=signodes[0] + ) + def _monkey_patch_python_domain_to_cross_link_parameters(): orig_after_content = PyObject.after_content @@ -471,10 +584,11 @@ def after_content(self: PyObject) -> None: ).parent signodes = obj_desc.children[:-1] - py = cast(PythonDomain, self.env.get_domain("py")) - noindex = "noindex" in self.options + node_ids = signodes[0].get("ids") + node_id = node_ids[0] if node_ids else "" + symbols = [] for signode in cast(List[docutils.nodes.Element], signodes): modname = signode["module"] @@ -485,7 +599,7 @@ def after_content(self: PyObject) -> None: if not symbols: return if self.objtype in ("class", "exception"): - # Any parameters are actually constructor parameters. To avoid + # Any function parameters are actually constructor parameters. To avoid # symbol name conflicts, assign object names under `__init__`. function_symbols = [f"{symbol}.__init__" for symbol in symbols] else: @@ -496,22 +610,69 @@ def after_content(self: PyObject) -> None: app=self.env.app, signodes=cast(List[sphinx.addnodes.desc_signature], signodes), content=getattr(self, "contentnode"), - symbols=function_symbols, + symbols=symbols, + function_symbols=function_symbols, noindex=noindex, + node_id=node_id, ) PyObject.after_content = after_content # type: ignore[assignment] +def _monkey_patch_python_domain_to_support_type_param_fields(): + """Adds support for type parameter fields in doc strings.""" + sphinx.domains.python.PyObject.doc_field_types.insert( + 0, + PyTypedField( + "type parameter", + label=_("Type Parameters"), + names=("tparam", "type parameter"), + typerolename="class", + typenames=("tparambound",), + can_collapse=True, + ), + ) + + +def _monkey_patch_napoleon_to_support_type_params(): + """Adds support for a `Type Parameters` section.""" + LOAD_CUSTOM_SECTIONS = "_load_custom_sections" + orig_load_custom_sections = getattr(GoogleDocstring, LOAD_CUSTOM_SECTIONS) + + def parse_type_parameters_section(self: GoogleDocstring, section: str) -> list[str]: + fields = self._consume_fields(multiple=True) + return self._format_docutils_params( + fields, field_role="tparam", type_role="tparambound" + ) + + def load_custom_sections(self: GoogleDocstring) -> None: + self._sections.setdefault( + "type parameters", + lambda section: parse_type_parameters_section(self, section), + ) + orig_load_custom_sections(self) + + setattr(GoogleDocstring, LOAD_CUSTOM_SECTIONS, load_custom_sections) + + +def _monkey_patch_python_domain_to_define_parameter_object_types(): + sphinx.domains.python.PythonDomain.object_types[ + "parameter" + ] = sphinx.domains.ObjType("parameter", "param") + + sphinx.domains.python.PythonDomain.object_types[ + "typeParameter" + ] = sphinx.domains.ObjType("type parameter", "param", "class") + + +_monkey_patch_python_domain_to_define_parameter_object_types() _monkey_patch_python_domain_to_cross_link_parameters() _monkey_patch_python_doc_fields() _monkey_patch_python_domain_to_store_func_in_ref_context() _monkey_patch_python_domain_to_resolve_params() _monkey_patch_python_domain_to_deprioritize_params_in_search() - -sphinx.domains.python.PythonDomain.object_types["parameter"] = sphinx.domains.ObjType( - "parameter", "param" -) +_monkey_patch_python_domain_to_support_type_param_fields() +_monkey_patch_napoleon_to_support_type_params() def setup(app: sphinx.application.Sphinx): diff --git a/sphinx_immaterial/apidoc/python/type_annotation_transforms.py b/sphinx_immaterial/apidoc/python/type_annotation_transforms.py index 5d2b6e229..5af966ba6 100644 --- a/sphinx_immaterial/apidoc/python/type_annotation_transforms.py +++ b/sphinx_immaterial/apidoc/python/type_annotation_transforms.py @@ -23,6 +23,8 @@ import sphinx.environment import sphinx.util.logging +from . import type_param_utils + # `ast.unparse` added in Python 3.9 if sys.version_info >= (3, 9): from ast import unparse as ast_unparse @@ -365,6 +367,22 @@ def type_to_xref( *args, suppress_prefix: bool = False, ) -> sphinx.addnodes.pending_xref: + if (type_param := type_param_utils.decode_type_param(target)) is not None: + refnode = sphinx.addnodes.pending_xref( + "", + docutils.nodes.Text(type_param.__name__), + refdomain="py", + reftype="param", + reftarget=type_param.__name__, + refspecific=True, + refexplicit=True, + refwarn=True, + ) + refnode["py:func"] = env.ref_context.get("py:func") + refnode["py:class"] = env.ref_context.get("py:class") + refnode["py:module"] = env.ref_context.get("py:module") + return refnode + if sphinx.version_info < (7, 2): # suppress_prefix may not have been used like a kwarg before v7.2.0 as # there was only 3 params for type_to_xref() prior to v7.2.0 @@ -397,10 +415,11 @@ def type_to_xref( sphinx.domains.python.type_to_xref = type_to_xref # type: ignore[assignment] -def setup(app: sphinx.application.Sphinx): - _monkey_patch_python_domain_to_transform_type_annotations() - _monkey_patch_python_domain_to_transform_xref_titles() +_monkey_patch_python_domain_to_transform_type_annotations() +_monkey_patch_python_domain_to_transform_xref_titles() + +def setup(app: sphinx.application.Sphinx): app.add_config_value( "python_type_aliases", default={}, diff --git a/sphinx_immaterial/apidoc/python/type_param_utils.py b/sphinx_immaterial/apidoc/python/type_param_utils.py new file mode 100644 index 000000000..f2ac78186 --- /dev/null +++ b/sphinx_immaterial/apidoc/python/type_param_utils.py @@ -0,0 +1,244 @@ +"""Utilities related to Python type parameter lists.""" + +import collections.abc +import sys +import re +import typing + +from sphinx.util.inspect import safe_getattr +import sphinx.util.typing + + +TYPE_VAR_ANNOTATION_PREFIX = "__SPHINX_IMMATERIAL_TYPE_VAR__" +"""Prefix used by monkey-patched `stringify_annotation` to indicate a TypeVar. + +This prefix is checked in a monkey-patched `type_to_xref` where it is converted +into a parameter reference. +""" + + +if sys.version_info >= (3, 10): + TypeParam = typing.Union[typing.TypeVar, typing.ParamSpec] +elif sys.version_info >= (3, 11): + TypeParam = typing.Union[typing.TypeVar, typing.TypeVarTuple, typing.ParamSpec] +else: + TypeParam = typing.TypeVar + + +def get_class_type_params(cls: type) -> tuple[TypeParam, ...]: + """Returns the ordered list of type parameters of a class.""" + + type_params = safe_getattr(cls, "__type_params__", ()) + if type_params: + return type_params + + origin = typing.get_origin(cls) + if origin is None: + origin = cls + if isinstance(origin, type) and not issubclass(origin, typing.Generic): + return () + + args = typing.get_args(cls) + if args: + return tuple(arg for arg in args if isinstance(arg, TypeParam)) + + bases = safe_getattr(cls, "__orig_bases__", ()) + if not bases: + bases = safe_getattr(cls, "__bases__", ()) + if not bases: + return () + + # First check for `typing.Generic`, since that takes precedence. + for base in bases: + if typing.get_origin(base) is typing.Generic: + return typing.get_args(base) + + params = {} + for base in bases: + cur_params = get_class_type_params(base) + for param in cur_params: + params.setdefault(param, True) + return tuple(params) + + +def stringify_type_params(type_params: collections.abc.Sequence[TypeParam]) -> str: + """Convert a type parameter list to its string representation. + + The string representation is suitable for embedding in a Python domain + signature. + """ + if not type_params: + return "" + parts = ["["] + for i, param in enumerate(type_params): + if i != 0: + parts.append(",") + if isinstance(param, typing.TypeVar): + parts.append(param.__name__) + parts.append("") + + if bound := param.__bound__: + parts.append(" : ") + parts.append(sphinx.util.typing.stringify_annotation(bound)) + if constraints := param.__constraints__: + parts.append(" : (") + for j, constraint in enumerate(constraints): + if j != 0: + parts.append(", ") + parts.append(sphinx.util.typing.stringify_annotation(constraint)) + parts.append(")") + parts.append("]") + return "".join(parts) + + +TypeParamSubstitutions = dict[str, str] + + +def substitute_type_params( + stringified_annotation: str, substitutions: typing.Optional[TypeParamSubstitutions] +) -> str: + if not substitutions: + return stringified_annotation + return _ENCODED_TYPE_PARAM_PATTERN.sub( + lambda m: substitutions.get(m[2], m[0]), + stringified_annotation, + ) + + +def _get_base_class_type_param_substitutions_impl( + cls: type, + substitutions: TypeParamSubstitutions, + base_classes: dict[type, TypeParamSubstitutions], +): + bases = safe_getattr(cls, "__orig_bases__", ()) + for base in bases: + base_origin = typing.get_origin(base) or base + if isinstance(base_origin, type) and not issubclass( + base_origin, typing.Generic + ): + continue + if base_origin is typing.Generic: + continue + params = get_class_type_params(base_origin) + args = typing.get_args(base) + + base_substitutions: TypeParamSubstitutions = {} + + for param, arg in zip(params, args): + s_arg = substitute_type_params( + sphinx.util.typing.stringify_annotation(arg), substitutions + ) + base_substitutions[param.__name__] = s_arg + + base_classes.setdefault(base_origin, base_substitutions) + + _get_base_class_type_param_substitutions_impl( + base_origin, base_substitutions, base_classes + ) + + +def get_base_class_type_param_substitutions( + cls: type, +) -> dict[type, TypeParamSubstitutions]: + base_classes: dict[type, TypeParamSubstitutions] = {} + _get_base_class_type_param_substitutions_impl(cls, {}, base_classes) + return base_classes + + +_ENCODE_TYPE_PARAM: dict[type[TypeParam], typing.Callable[[TypeParam], str]] = { + typing.TypeVar: lambda annotation: ( + TYPE_VAR_ANNOTATION_PREFIX + "V_" + annotation.__name__ + ), +} +_DECODE_TYPE_PARAM: dict[str, typing.Callable[[str], TypeParam]] = { + "V": typing.TypeVar, +} + +if sys.version_info >= (3, 10): + _ENCODE_TYPE_PARAM[typing.ParamSpec] = ( + lambda annotation: TYPE_VAR_ANNOTATION_PREFIX + "P_" + annotation.__name__ + ) + _DECODE_TYPE_PARAM["P"] = typing.ParamSpec +if sys.version_info >= (3, 11): + _ENCODE_TYPE_PARAM[typing.TypeVarTuple] = ( + lambda annotation: TYPE_VAR_ANNOTATION_PREFIX + "T_" + annotation.__name__ + ) + _DECODE_TYPE_PARAM["T"] = typing.TypeVarTuple + + +_ENCODED_TYPE_PARAM_PATTERN = re.compile( + r"\b" + + TYPE_VAR_ANNOTATION_PREFIX + + "([" + + "".join(_DECODE_TYPE_PARAM.keys()) + + r"])_(\w+)\b" +) + + +def decode_type_param(annotation: str) -> typing.Optional[TypeParam]: + m = _ENCODED_TYPE_PARAM_PATTERN.fullmatch(annotation) + if m is None: + return None + kind = m[1] + name = m[2] + decode = _DECODE_TYPE_PARAM[kind] + return decode(name) + + +def encode_type_param(annotation: TypeParam) -> str: + return _ENCODE_TYPE_PARAM[type(annotation)](annotation) + + +def get_type_params_from_signature(signature: str) -> dict[str, TypeParam]: + params = {} + for m in _ENCODED_TYPE_PARAM_PATTERN.finditer(signature): + name = m[2] + if name in params: + continue + kind = m[1] + decode = _DECODE_TYPE_PARAM[kind] + params[name] = decode(name) + return params + + +def _monkey_patch_stringify_annotation_to_support_type_params(): + """In order to properly resolve references to type parameters in signatures, + they need to be given the `py:param` role with a target of the unqualified + type name. + + Normally, when `sphinx.ext.autodoc` encounters a TypeVar within a type + annotation, it formats it as `.`. As this is + indistinguishable from any other class name, our monkey-patched + `type_to_xref` would have no way to know when to create a `py:param` + reference instead of the usual `py:class` reference. + + As a workaround, monkey-patch `stringify_annotation` to format `TypeVar` + annotations as `TYPE_VAR_ANNOTATION_PREFIX + type_name`. This special prefix + is then detected and stripped off by the monkey-patched `type_to_xref` + defined in type_annotation_transforms.py. + """ + orig = sphinx.util.typing.stringify_annotation + + def stringify_annotation( + annotation: typing.Any, + /, + mode="fully-qualified-except-typing", + ) -> str: + if (encode := _ENCODE_TYPE_PARAM.get(type(annotation))) is not None: + return encode(annotation) + return orig(annotation, mode=mode) + + for module in [ + "sphinx.util.typing", + "sphinx.ext.autodoc", + "sphinx.ext.autodoc.typehints", + "sphinx.util.inspect", + "sphinx.ext.napoleon.docstring", + ]: + m = sys.modules.get(module) + if m is None: + continue + setattr(m, "stringify_annotation", stringify_annotation) + + +_monkey_patch_stringify_annotation_to_support_type_params() diff --git a/tests/python_type_param_utils_test.py b/tests/python_type_param_utils_test.py new file mode 100644 index 000000000..7ad122a81 --- /dev/null +++ b/tests/python_type_param_utils_test.py @@ -0,0 +1,42 @@ +import typing + +from sphinx_immaterial.apidoc.python import type_param_utils +from sphinx_immaterial.apidoc.python import parameter_objects + + +T = typing.TypeVar("T") +U = typing.TypeVar("U") + + +class Foo(typing.Generic[T, U]): + pass + + +class Bar1(Foo): + pass + + +class Bar2(Foo, typing.Generic[U, T]): + pass + + +class Bar3(Foo[int, T]): + pass + + +def test_get_class_type_params(): + assert type_param_utils.get_class_type_params(int) == () + assert type_param_utils.get_class_type_params(Foo) == (T, U) + assert type_param_utils.get_class_type_params(Bar1) == (T, U) + assert type_param_utils.get_class_type_params(Bar2) == (U, T) + assert type_param_utils.get_class_type_params(Bar3) == (T,) + + +def test_get_base_class_type_param_substitutions(): + class Bar4(Bar3[tuple[U, int]]): + pass + + assert type_param_utils.get_base_class_type_param_substitutions(Bar4) == { + Foo: {"T": "int", "U": "tuple[__SPHINX_IMMATERIAL_TYPE_VAR__U, int]"}, + Bar3: {"T": "tuple[__SPHINX_IMMATERIAL_TYPE_VAR__U, int]"}, + }