diff --git a/src/orion/algo/space/__init__.py b/src/orion/algo/space/__init__.py index eafb9794a..ffdb496ee 100644 --- a/src/orion/algo/space/__init__.py +++ b/src/orion/algo/space/__init__.py @@ -32,6 +32,8 @@ import copy import logging import numbers +from dataclasses import dataclass, field +from distutils.log import error from functools import singledispatch from typing import Any, Generic, TypeVar @@ -1095,6 +1097,44 @@ def __setitem__(self, key, value): ) super().__setitem__(key, value) + def assert_contains(self, trial): + """Same as __contains__ but instead of return true or false it will raise an exception + with the exact causes of the mismatch. + + Raises + ------ + ValueError if the trial has parameters that are not contained by the space. + + """ + if isinstance(trial, str): + if not super().__contains__(trial): + raise ValueError("{trial} does not belong to the dimension") + return + + flattened_params = flatten(trial.params) + keys = set(flattened_params.keys()) + errors = [] + + for dim_name, dim in self.items(): + if dim_name not in keys: + errors.append(f"{dim_name} is missing") + continue + + value = flattened_params[dim_name] + if value not in dim: + errors.append(f"{value} does not belong to the dimension {dim}") + + keys.remove(dim_name) + + if len(errors) > 0: + raise ValueError(f"Trial {trial.id} is not contained in space:\n{errors}") + + if len(keys) != 0: + errors = "\n - ".join(keys) + raise ValueError(f"Trial {trial.id} has additional parameters:\n{errors}") + + return True + def __contains__(self, key_or_trial): """Check whether `trial` is within the bounds of the space. Or check if a name for a dimension is registered in this space. @@ -1105,19 +1145,12 @@ def __contains__(self, key_or_trial): If str, test if the string is a dimension part of the search space. If a Trial, test if trial's hyperparameters fit the current search space. """ - if isinstance(key_or_trial, str): - return super().__contains__(key_or_trial) - - trial = key_or_trial - flattened_params = flatten(trial.params) - keys = set(flattened_params.keys()) - for dim_name, dim in self.items(): - if dim_name not in keys or flattened_params[dim_name] not in dim: - return False - - keys.remove(dim_name) - return len(keys) == 0 + try: + self.assert_contains(key_or_trial) + return True + except ValueError: + return False def __repr__(self): """Represent as a string the space and the dimensions it contains.""" diff --git a/src/orion/core/worker/algo_wrappers/space_transform.py b/src/orion/core/worker/algo_wrappers/space_transform.py index eb145fb94..208d13e0d 100644 --- a/src/orion/core/worker/algo_wrappers/space_transform.py +++ b/src/orion/core/worker/algo_wrappers/space_transform.py @@ -76,8 +76,4 @@ def reverse_transform(self, trial: Trial) -> Trial: def _verify_trial(self, trial: Trial, space: Space | None = None) -> None: space = space or self.space - if trial not in space: - raise ValueError( - f"Trial {trial.id} not contained in space:" - f"\nParams: {trial.params}\nSpace: {space}" - ) + space.assert_contains(trial) diff --git a/src/orion/core/worker/transformer.py b/src/orion/core/worker/transformer.py index b6b5fe137..885d6477e 100644 --- a/src/orion/core/worker/transformer.py +++ b/src/orion/core/worker/transformer.py @@ -434,7 +434,7 @@ def __init__(self, categories): self.categories = categories map_dict = {cat: i for i, cat in enumerate(categories)} self._map = numpy.vectorize(lambda x: map_dict[x], otypes="i") - self._imap = numpy.vectorize(lambda x: categories[x], otypes=[numpy.object]) + self._imap = numpy.vectorize(lambda x: categories[x], otypes=[object]) def __deepcopy__(self, memo): """Make a deepcopy""" @@ -866,6 +866,19 @@ def sample(self, n_samples=1, seed=None): trials = self.original.sample(n_samples=n_samples, seed=seed) return [self.reshape(trial) for trial in trials] + def assert_contains(self, trial): + """Check if the trial or key is contained inside the space, if not an exception is raised + + Raises + ------ + TypeError when a dimension is not compatible with the space + + """ + if isinstance(trial, str): + super().assert_contains(trial) + + return self.original.assert_contains(self.restore_shape(trial)) + def __contains__(self, key_or_trial): """Check whether `trial` is within the bounds of the space. Or check if a name for a dimension is registered in this space. @@ -877,10 +890,11 @@ def __contains__(self, key_or_trial): If a Trial, test if trial's hyperparameters fit the current search space. """ - if isinstance(key_or_trial, str): - return super().__contains__(key_or_trial) - - return self.restore_shape(key_or_trial) in self.original + try: + self.assert_contains(key_or_trial) + return True + except ValueError: + return False @property def cardinality(self): diff --git a/tests/unittests/algo/test_space.py b/tests/unittests/algo/test_space.py index 4e32fa4ce..f3961670e 100644 --- a/tests/unittests/algo/test_space.py +++ b/tests/unittests/algo/test_space.py @@ -624,7 +624,7 @@ def test_cast_list_multidim(self): categories[0] = "asdfa" categories[2] = "lalala" dim = Categorical("yolo", categories, shape=2) - sample = ["asdfa", "1"] # np.array(['asdfa', '1'], dtype=np.object) + sample = ["asdfa", "1"] # np.array(['asdfa', '1'], dtype=object) assert dim.cast(sample) == ["asdfa", 1] def test_cast_array_multidim(self): @@ -633,14 +633,14 @@ def test_cast_array_multidim(self): categories[0] = "asdfa" categories[2] = "lalala" dim = Categorical("yolo", categories, shape=2) - sample = np.array(["asdfa", "1"], dtype=np.object) - assert np.all(dim.cast(sample) == np.array(["asdfa", 1], dtype=np.object)) + sample = np.array(["asdfa", "1"], dtype=object) + assert np.all(dim.cast(sample) == np.array(["asdfa", 1], dtype=object)) def test_cast_bad_category(self): """Make sure array are cast to int and returned as array of values""" categories = list(range(10)) dim = Categorical("yolo", categories, shape=2) - sample = np.array(["asdfa", "1"], dtype=np.object) + sample = np.array(["asdfa", "1"], dtype=object) with pytest.raises(ValueError) as exc: dim.cast(sample) assert "Invalid category: asdfa" in str(exc.value) diff --git a/tests/unittests/core/worker/algo_wrappers/test_transform.py b/tests/unittests/core/worker/algo_wrappers/test_transform.py index 830dfc1e2..63fde57ea 100644 --- a/tests/unittests/core/worker/algo_wrappers/test_transform.py +++ b/tests/unittests/core/worker/algo_wrappers/test_transform.py @@ -48,7 +48,17 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space assert algo_wrapper.space is space - with pytest.raises(ValueError, match="not contained in space:"): + with pytest.raises(ValueError, match="yolo is missing"): + invalid_trial = Trial( + params=[ + dict(name="yolo2", value=0, type="real"), + dict(name="yolo3", value=3.5, type="real"), + ], + status="new", + ) + algo_wrapper._verify_trial(invalid_trial) + + with pytest.raises(ValueError, match="does not belong to the dimension"): invalid_trial = format_trials.tuple_to_trial((("asdfa", 2), 10, 3.5), space) algo_wrapper._verify_trial(invalid_trial) @@ -59,7 +69,6 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space # transform point ttrial = tspace.transform(trial) - # TODO: /~https://github.com/Epistimio/orion/issues/804 assert ttrial in tspace # Transformed point is not in original space