Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: automatically load external frontends in load() #4285

Merged
merged 6 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions doc/source/developing/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,37 @@ In subsequent versions, we plan to include in yt a catalog of known extensions
and where to find them; this will put discoverability directly into the code
base.

Frontend as an extension
------------------------

Starting with version 4.2 of yt, any externally installed package that exports
:class:`~yt.data_objects.static_output.Dataset` subclass as an entrypoint in
``yt.frontends`` namespace in ``setup.py`` or ``pyproject.toml`` will be
automatically loaded and immediately available in :func:`~yt.loaders.load`.

To add an entrypoint in an external project's ``setup.py``:

.. code-block:: python

setup(
# ...,
entry_points={
"yt.frontends": [
"myFrontend = my_frontend.api.MyFrontendDataset",
"myOtherFrontend = my_frontend.api.MyOtherFrontendDataset",
]
}
)

or ``pyproject.toml``:

.. code-block:: toml

[project.entry-points."yt.frontends"]
myFrontend = "my_frontend.api:MyFrontendDataset"
myOtherFrontend = "my_frontend.api:MyOtherFrontendDataset"


Extension Template
------------------

Expand Down
2 changes: 1 addition & 1 deletion nose_unit.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ nologcapture=1
verbosity=2
where=yt
with-timer=1
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_set_zlim\.py|test_add_field\.py|test_glue\.py|test_geometries\.py|test_firefly\.py|test_callable_grids\.py)
ignore-files=(test_load_errors.py|test_load_sample.py|test_commons.py|test_ambiguous_fields.py|test_field_access_pytest.py|test_save.py|test_line_annotation_unit.py|test_eps_writer.py|test_registration.py|test_invalid_origin.py|test_outputs_pytest\.py|test_normal_plot_api\.py|test_load_archive\.py|test_stream_particles\.py|test_file_sanitizer\.py|test_version\.py|\test_on_demand_imports\.py|test_set_zlim\.py|test_add_field\.py|test_glue\.py|test_geometries\.py|test_firefly\.py|test_callable_grids\.py|test_external_frontends\.py)
exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF
1 change: 1 addition & 0 deletions tests/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ other_tests:
- "--ignore-file=test_geometries\\.py"
- "--ignore-file=test_firefly\\.py"
- "--ignore-file=test_callable_grids\\.py"
- "--ignore-file=test_external_frontends\\.py"
- "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
- "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
- "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles"
Expand Down
7 changes: 5 additions & 2 deletions yt/frontends/sdf/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,11 @@ def _is_valid(cls, filename, *args, **kwargs):
# Grab a whole 4k page.
line = next(hreq.iter_content(4096))
elif os.path.isfile(sdf_header):
with open(sdf_header, encoding="ISO-8859-1") as f:
line = f.read(10).strip()
try:
with open(sdf_header, encoding="ISO-8859-1") as f:
line = f.read(10).strip()
except PermissionError:
return False
else:
return False
return line.startswith("# SDF")
10 changes: 10 additions & 0 deletions yt/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import types
import warnings
from importlib.metadata import entry_points
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection
from pathlib import Path
Expand Down Expand Up @@ -97,6 +98,15 @@ def load(
if not fn.startswith("http"):
fn = str(lookup_on_disk_data(fn))

if sys.version_info >= (3, 10):
external_frontends = entry_points(group="yt.frontends")
else:
external_frontends = entry_points().get("yt.frontends", [])

# Ensure that external frontends are loaded
for entrypoint in external_frontends:
entrypoint.load()
neutrinoceros marked this conversation as resolved.
Show resolved Hide resolved

candidates = []
for cls in output_type_registry.values():
if cls._is_valid(fn, *args, **kwargs):
Expand Down
59 changes: 59 additions & 0 deletions yt/tests/test_external_frontends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import sys

import pytest

import yt
from yt.data_objects.static_output import Dataset
from yt.geometry.grid_geometry_handler import GridIndex
from yt.utilities.object_registries import output_type_registry


class MockEntryPoint:
@classmethod
def load(cls):
class MockHierarchy(GridIndex):
grid = None

class ExtDataset(Dataset):
_index_class = MockHierarchy

def _parse_parameter_file(self):
self.current_time = 1.0
self.cosmological_simulation = 0

def _set_code_unit_attributes(self):
self.length_unit = self.quan(1.0, "code_length")
self.mass_unit = self.quan(1.0, "code_mass")
self.time_unit = self.quan(1.0, "code_time")

@classmethod
def _is_valid(cls, filename, *args, **kwargs):
return filename.endswith("mock")


@pytest.fixture()
def mock_external_frontend(monkeypatch):
def mock_entry_points(group=None):
if sys.version_info >= (3, 10):
return [MockEntryPoint]
else:
return {"yt.frontends": [MockEntryPoint]}

monkeypatch.setattr(yt.loaders, "entry_points", mock_entry_points)
assert "ExtDataset" not in output_type_registry

yield

assert "ExtDataset" in output_type_registry
# teardown to avoid test pollution
output_type_registry.pop("ExtDataset")


@pytest.mark.usefixtures("mock_external_frontend")
def test_external_frontend(tmp_path):
test_file = tmp_path / "tmp.mock"
test_file.write_text("") # create the file
assert test_file.is_file()

ds = yt.load(test_file)
assert "ExtDataset" in ds.__class__.__name__