Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add assert_contains to print a more precise error message #1003

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions src/orion/algo/space/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,43 @@ def __setitem__(self, key, value):
)
super().__setitem__(key, value)

def assert_contains(self, trial):
"""Same as __contains__ but instead of returning true or false it will raise an exception
with the exact causes of the mismatch.

Raises
------
ValueError is the trial has parameters that are not contained by the space.

"""
if isinstance(trial, str):
return super().__contains__(trial)

flattened_params = flatten(trial.params)
keys = set(flattened_params.keys())
errors = []

for dim_name, dim in self.items():
if dim_name not in keys:
errors.append(f"{dim_name} is missing")
continue

value = flattened_params[dim_name]
if value not in dim:
errors.append(f"{value} does not belong to the dimension {dim}")

keys.remove(dim_name)

if len(errors) > 0:
errors = "\n - ".join(errors)
raise ValueError(f"Trial {trial.id} is not contained in space:\n{errors}")

if len(keys) != 0:
errors = "\n - ".join(keys)
raise ValueError(f"Trial {trial.id} has additional parameters:\n{errors}")

return True

def __contains__(self, key_or_trial):
"""Check whether `trial` is within the bounds of the space.
Or check if a name for a dimension is registered in this space.
Expand All @@ -1105,19 +1142,10 @@ def __contains__(self, key_or_trial):
If str, test if the string is a dimension part of the search space.
If a Trial, test if trial's hyperparameters fit the current search space.
"""
if isinstance(key_or_trial, str):
return super().__contains__(key_or_trial)

trial = key_or_trial
flattened_params = flatten(trial.params)
keys = set(flattened_params.keys())
for dim_name, dim in self.items():
if dim_name not in keys or flattened_params[dim_name] not in dim:
return False

keys.remove(dim_name)

return len(keys) == 0
try:
return self.assert_contains(key_or_trial)
except ValueError:
return False

def __repr__(self):
"""Represent as a string the space and the dimensions it contains."""
Expand Down
8 changes: 8 additions & 0 deletions tests/unittests/algo/test_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,3 +1004,11 @@ def test_precision(self):

with pytest.raises(TypeError):
space.register(Real("yolo4", "norm", 0.9, precision=0.6))


def test_issue_1018():
from orion.algo.space import Real

Real("a1", "uniform", -4.0, 0.0, default_value=-0.2)
Real("a2", "uniform", -4.0, 0.0, low=-4.0, high=0.0, default_value=-0.2)
Real("a3", "uniform", low=-4.0, high=0.0, default_value=-0.2)
27 changes: 24 additions & 3 deletions tests/unittests/core/worker/algo_wrappers/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,29 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space

assert algo_wrapper.space is space

with pytest.raises(ValueError, match="not contained in space:"):
with pytest.raises(ValueError, match="yolo is missing"):
invalid_trial = Trial(
params=[
dict(name="yolo2", value=0, type="real"),
dict(name="yolo3", value=3.5, type="real"),
],
status="new",
)
algo_wrapper._verify_trial(invalid_trial)

with pytest.raises(ValueError, match="has additional parameters"):
invalid_trial = Trial(
params=[
dict(name="yolo", value=("asdfa", 2), type="real"),
dict(name="yolo2", value=0, type="real"),
dict(name="yolo3", value=3.5, type="real"),
dict(name="yolo4", value=0, type="real"),
],
status="new",
)
algo_wrapper._verify_trial(invalid_trial)

with pytest.raises(ValueError, match="does not belong to the dimension"):
invalid_trial = format_trials.tuple_to_trial((("asdfa", 2), 10, 3.5), space)
algo_wrapper._verify_trial(invalid_trial)

Expand All @@ -59,8 +81,7 @@ def test_verify_trial(self, algo_wrapper: SpaceTransform[DumbAlgo], space: Space

# transform point
ttrial = tspace.transform(trial)
# TODO: /~https://github.com/Epistimio/orion/issues/804
assert ttrial in tspace
tspace.assert_contains(ttrial)

# Transformed point is not in original space
with pytest.raises(ValueError, match="not contained in space:"):
Expand Down